diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 62c2974ec..95ce39cbe 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -1,4 +1,5 @@ import logging +import urllib.parse from datetime import timedelta from types import TracebackType from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, @@ -17,6 +18,25 @@ logger = logging.getLogger('databricks.sdk') +def _fix_host_if_needed(host: Optional[str]) -> Optional[str]: + if not host: + return host + + # Add a default scheme if it's missing + if '://' not in host: + host = 'https://' + host + + o = urllib.parse.urlparse(host) + # remove trailing slash + path = o.path.rstrip('/') + # remove port if 443 + netloc = o.netloc + if o.port == 443: + netloc = netloc.split(':')[0] + + return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + + class _BaseClient: def __init__(self, diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 5cae1b2b4..b4efdf603 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -10,11 +10,14 @@ import requests from . import useragent +from ._base_client import _fix_host_if_needed from .clock import Clock, RealClock from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) -from .oauth import OidcEndpoints, Token +from .oauth import (OidcEndpoints, Token, get_account_endpoints, + get_azure_entra_id_workspace_endpoints, + get_workspace_endpoints) logger = logging.getLogger('databricks.sdk') @@ -254,24 +257,10 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]: if not self.host: return None if self.is_azure and self.azure_client_id: - # Retrieve authorize endpoint to retrieve token endpoint after - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) - real_auth_url = res.headers.get('location') - if not real_auth_url: - return None - return OidcEndpoints(authorization_endpoint=real_auth_url, - token_endpoint=real_auth_url.replace('/authorize', '/token')) + return get_azure_entra_id_workspace_endpoints(self.host) if self.is_account_client and self.account_id: - prefix = f'{self.host}/oidc/accounts/{self.account_id}' - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', - token_endpoint=f'{prefix}/v1/token') - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' - res = requests.get(oidc) - if res.status_code != 200: - return None - auth_metadata = res.json() - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), - token_endpoint=auth_metadata.get('token_endpoint')) + return get_account_endpoints(self.host, self.account_id) + return get_workspace_endpoints(self.host) def debug_string(self) -> str: """ Returns log-friendly representation of configured attributes """ @@ -346,22 +335,9 @@ def attributes(cls) -> Iterable[ConfigAttribute]: return cls._attributes def _fix_host_if_needed(self): - if not self.host: - return - - # Add a default scheme if it's missing - if '://' not in self.host: - self.host = 'https://' + self.host - - o = urllib.parse.urlparse(self.host) - # remove trailing slash - path = o.path.rstrip('/') - # remove port if 443 - netloc = o.netloc - if o.port == 443: - netloc = netloc.split(':')[0] - - self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + updated_host = _fix_host_if_needed(self.host) + if updated_host: + self.host = updated_host def load_azure_tenant_id(self): """[Internal] Load the Azure tenant ID from the Azure Databricks login page. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 232465dab..a79151b5a 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -187,30 +187,35 @@ def token() -> Token: def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: if cfg.auth_type != 'external-browser': return None + client_id, client_secret = None, None if cfg.client_id: client_id = cfg.client_id - elif cfg.is_aws: + client_secret = cfg.client_secret + elif cfg.azure_client_id: + client_id = cfg.azure_client + client_secret = cfg.azure_client_secret + + if not client_id: client_id = 'databricks-cli' - elif cfg.is_azure: - # Use Azure AD app for cases when Azure CLI is not available on the machine. - # App has to be registered as Single-page multi-tenant to support PKCE - # TODO: temporary app ID, change it later. - client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' - else: - raise ValueError(f'local browser SSO is not supported') - oauth_client = OAuthClient(host=cfg.host, - client_id=client_id, - redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) # Load cached credentials from disk if they exist. # Note that these are local to the Python SDK and not reused by other SDKs. - token_cache = TokenCache(oauth_client) + oidc_endpoints = cfg.oidc_endpoints + redirect_url = 'http://localhost:8020' + token_cache = TokenCache(host=cfg.host, + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) credentials = token_cache.load() if credentials: # Force a refresh in case the loaded credentials are expired. credentials.token() else: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + redirect_url=redirect_url, + client_secret=client_secret) consent = oauth_client.initiate_consent() if not consent: return None diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e9a3afb90..6cac45afc 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -17,6 +17,8 @@ import requests import requests.auth +from ._base_client import _BaseClient, _fix_host_if_needed + # Error code for PKCE flow in Azure Active Directory, that gets additional retry. # See https://stackoverflow.com/a/75466778/277035 for more info NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327' @@ -46,8 +48,24 @@ def __call__(self, r): @dataclass class OidcEndpoints: + """ + The endpoints used for OAuth-based authentication in Databricks. + """ + authorization_endpoint: str # ../v1/authorize + """The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for + the user to login and authorize the client for user-to-machine (U2M) flows.""" + token_endpoint: str # ../v1/token + """The token endpoint for the OAuth flow.""" + + @staticmethod + def from_dict(d: dict) -> 'OidcEndpoints': + return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'), + token_endpoint=d.get('token_endpoint')) + + def as_dict(self) -> dict: + return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint} @dataclass @@ -220,18 +238,76 @@ def do_GET(self): self.wfile.write(b'You can close this tab.') +def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given account. + :param host: The Databricks account host. + :param account_id: The account ID. + :return: The account's OIDC endpoints. + """ + host = _fix_host_if_needed(host) + oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given workspace. + :param host: The Databricks workspace host. + :return: The workspace's OIDC endpoints. + """ + host = _fix_host_if_needed(host) + oidc = f'{host}/oidc/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]: + """ + Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks + using an application registered in Azure Entra ID. + :param host: The Databricks workspace host. + :return: The OIDC endpoints for the workspace's Azure Entra ID tenant. + """ + # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint + host = _fix_host_if_needed(host) + res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) + real_auth_url = res.headers.get('location') + if not real_auth_url: + return None + return OidcEndpoints(authorization_endpoint=real_auth_url, + token_endpoint=real_auth_url.replace('/authorize', '/token')) + + class SessionCredentials(Refreshable): - def __init__(self, client: 'OAuthClient', token: Token): - self._client = client + def __init__(self, + token: Token, + token_endpoint: str, + client_id: str, + client_secret: str = None, + redirect_url: str = None): + self._token_endpoint = token_endpoint + self._client_id = client_id + self._client_secret = client_secret + self._redirect_url = redirect_url super().__init__(token) def as_dict(self) -> dict: return {'token': self._token.as_dict()} @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials': - return SessionCredentials(client=client, token=Token.from_dict(raw['token'])) + def from_dict(raw: dict, + token_endpoint: str, + client_id: str, + client_secret: str = None, + redirect_url: str = None) -> 'SessionCredentials': + return SessionCredentials(token=Token.from_dict(raw['token']), + token_endpoint=token_endpoint, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) def auth_type(self): """Implementing CredentialsProvider protocol""" @@ -252,13 +328,13 @@ def refresh(self) -> Token: raise ValueError('oauth2: token expired and refresh token is not set') params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token} headers = {} - if 'microsoft' in self._client.token_url: + if 'microsoft' in self._token_endpoint: # Tokens issued for the 'Single-Page Application' client-type may # only be redeemed via cross-origin requests - headers = {'Origin': self._client.redirect_url} - return retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + headers = {'Origin': self._redirect_url} + return retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._token_endpoint, params=params, use_params=True, headers=headers) @@ -266,27 +342,53 @@ def refresh(self) -> Token: class Consent: - def __init__(self, client: 'OAuthClient', state: str, verifier: str, auth_url: str = None) -> None: - self.auth_url = auth_url - + def __init__(self, + state: str, + verifier: str, + authorization_url: str, + redirect_url: str, + token_endpoint: str, + client_id: str, + client_secret: str = None) -> None: self._verifier = verifier self._state = state - self._client = client + self._authorization_url = authorization_url + self._redirect_url = redirect_url + self._token_endpoint = token_endpoint + self._client_id = client_id + self._client_secret = client_secret def as_dict(self) -> dict: - return {'state': self._state, 'verifier': self._verifier} + return { + 'state': self._state, + 'verifier': self._verifier, + 'authorization_url': self._authorization_url, + 'redirect_url': self._redirect_url, + 'token_endpoint': self._token_endpoint, + 'client_id': self._client_id, + } + + @property + def authorization_url(self) -> str: + return self._authorization_url @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent': - return Consent(client, raw['state'], raw['verifier']) + def from_dict(raw: dict, client_secret: str = None) -> 'Consent': + return Consent(raw['state'], + raw['verifier'], + authorization_url=raw['authorization_url'], + redirect_url=raw['redirect_url'], + token_endpoint=raw['token_endpoint'], + client_id=raw['client_id'], + client_secret=client_secret) def launch_external_browser(self) -> SessionCredentials: - redirect_url = urllib.parse.urlparse(self._client.redirect_url) + redirect_url = urllib.parse.urlparse(self._redirect_url) if redirect_url.hostname not in ('localhost', '127.0.0.1'): raise ValueError(f'cannot listen on {redirect_url.hostname}') feedback = [] - logger.info(f'Opening {self.auth_url} in a browser') - webbrowser.open_new(self.auth_url) + logger.info(f'Opening {self._authorization_url} in a browser') + webbrowser.open_new(self._authorization_url) port = redirect_url.port handler_factory = functools.partial(_OAuthCallback, feedback) with HTTPServer(("localhost", port), handler_factory) as httpd: @@ -308,7 +410,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: if self._state != state: raise ValueError('state mismatch') params = { - 'redirect_uri': self._client.redirect_url, + 'redirect_uri': self._redirect_url, 'grant_type': 'authorization_code', 'code_verifier': self._verifier, 'code': code @@ -316,19 +418,20 @@ def exchange(self, code: str, state: str) -> SessionCredentials: headers = {} while True: try: - token = retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + token = retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._token_endpoint, params=params, headers=headers, use_params=True) - return SessionCredentials(self._client, token) + return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret, + self._redirect_url) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): # Retry in cases of 'Single-Page Application' client-type with # 'Origin' header equal to client's redirect URL. - headers['Origin'] = self._client.redirect_url - msg = f'Retrying OAuth token exchange with {self._client.redirect_url} origin' + headers['Origin'] = self._redirect_url + msg = f'Retrying OAuth token exchange with {self._redirect_url} origin' logger.debug(msg) continue raise e @@ -354,13 +457,28 @@ class OAuthClient: """ def __init__(self, - host: str, - client_id: str, + oidc_endpoints: OidcEndpoints, redirect_url: str, - *, + client_id: str, scopes: List[str] = None, client_secret: str = None): - # TODO: is it a circular dependency?.. + + if not scopes: + scopes = ['all-apis'] + + self.redirect_url = redirect_url + self._client_id = client_id + self._client_secret = client_secret + self._oidc_endpoints = oidc_endpoints + self._scopes = scopes + + @staticmethod + def from_host(host: str, + client_id: str, + redirect_url: str, + *, + scopes: List[str] = None, + client_secret: str = None) -> 'OAuthClient': from .core import Config from .credentials_provider import credentials_strategy @@ -374,18 +492,7 @@ def noop_credentials(_: any): oidc = config.oidc_endpoints if not oidc: raise ValueError(f'{host} does not support OAuth') - - self.host = host - self.redirect_url = redirect_url - self.client_id = client_id - self.client_secret = client_secret - self.token_url = oidc.token_endpoint - self.is_aws = config.is_aws - self.is_azure = config.is_azure - self.is_gcp = config.is_gcp - - self._auth_url = oidc.authorization_endpoint - self._scopes = scopes + return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret) def initiate_consent(self) -> Consent: state = secrets.token_urlsafe(16) @@ -397,18 +504,24 @@ def initiate_consent(self) -> Consent: params = { 'response_type': 'code', - 'client_id': self.client_id, + 'client_id': self._client_id, 'redirect_uri': self.redirect_url, 'scope': ' '.join(self._scopes), 'state': state, 'code_challenge': challenge, 'code_challenge_method': 'S256' } - url = f'{self._auth_url}?{urllib.parse.urlencode(params)}' - return Consent(self, state, verifier, auth_url=url) + auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}' + return Consent(state, + verifier, + authorization_url=auth_url, + redirect_url=self.redirect_url, + token_endpoint=self._oidc_endpoints.token_endpoint, + client_id=self._client_id, + client_secret=self._client_secret) def __repr__(self) -> str: - return f'' + return f'' @dataclass @@ -448,17 +561,28 @@ def refresh(self) -> Token: use_header=self.use_header) -class TokenCache(): +class TokenCache: BASE_PATH = "~/.config/databricks-sdk-py/oauth" - def __init__(self, client: OAuthClient) -> None: - self.client = client + def __init__(self, + host: str, + oidc_endpoints: OidcEndpoints, + client_id: str, + redirect_url: str = None, + client_secret: str = None, + scopes: List[str] = None) -> None: + self._host = host + self._client_id = client_id + self._oidc_endpoints = oidc_endpoints + self._redirect_url = redirect_url + self._client_secret = client_secret + self._scopes = scopes or [] @property def filename(self) -> str: # Include host, client_id, and scopes in the cache filename to make it unique. hash = hashlib.sha256() - for chunk in [self.client.host, self.client.client_id, ",".join(self.client._scopes), ]: + for chunk in [self._host, self._client_id, ",".join(self._scopes), ]: hash.update(chunk.encode('utf-8')) return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json")) @@ -472,7 +596,11 @@ def load(self) -> Optional[SessionCredentials]: try: with open(self.filename, 'r') as f: raw = json.load(f) - return SessionCredentials.from_dict(self.client, raw) + return SessionCredentials.from_dict(raw, + token_endpoint=self._oidc_endpoints.token_endpoint, + client_id=self._client_id, + client_secret=self._client_secret, + redirect_url=self._redirect_url) except Exception: return None diff --git a/examples/external_browser_auth.py b/examples/external_browser_auth.py new file mode 100644 index 000000000..061ff60c7 --- /dev/null +++ b/examples/external_browser_auth.py @@ -0,0 +1,72 @@ +from databricks.sdk import WorkspaceClient +import argparse +import logging + +logging.basicConfig(level=logging.DEBUG) + + +def register_custom_app(confidential: bool) -> tuple[str, str]: + """Creates new Custom OAuth App in Databricks Account""" + logging.info("No OAuth custom app client/secret provided, creating new app") + + from databricks.sdk import AccountClient + + account_client = AccountClient() + + custom_app = account_client.custom_app_integration.create( + name="external-browser-demo", + redirect_urls=[ + f"http://localhost:8020", + ], + confidential=confidential, + scopes=["all-apis"], + ) + logging.info(f"Created new custom app: " + f"--client_id {custom_app.client_id} " + f"{'--client_secret ' + custom_app.client_secret if confidential else ''}") + + return custom_app.client_id, custom_app.client_secret + + +def delete_custom_app(client_id: str): + """Creates new Custom OAuth App in Databricks Account""" + logging.info(f"Deleting custom app {client_id}") + from databricks.sdk import AccountClient + account_client = AccountClient() + account_client.custom_app_integration.delete(client_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", help="Databricks host", required=True) + parser.add_argument("--client_id", help="Databricks client_id", default=None) + parser.add_argument("--azure_client_id", help="Databricks azure_client_id", default=None) + parser.add_argument("--client_secret", help="Databricks client_secret", default=None) + parser.add_argument("--azure_client_secret", help="Databricks azure_client_secret", default=None) + parser.add_argument("--register-custom-app", action="store_true", help="Register a new custom app") + parser.add_argument("--register-custom-app-confidential", action="store_true", help="Register a new custom app") + namespace = parser.parse_args() + if namespace.register_custom_app and (namespace.client_id is not None or namespace.azure_client_id is not None): + raise ValueError("Cannot register custom app and provide --client_id/--azure_client_id at the same time") + if not namespace.register_custom_app and namespace.client_id is None and namespace.azure_client_secret is None: + raise ValueError("Must provide --client_id/--azure_client_id or register a custom app") + if namespace.register_custom_app: + client_id, client_secret = register_custom_app(namespace.register_custom_app_confidential) + else: + client_id, client_secret = namespace.client_id, namespace.client_secret + + w = WorkspaceClient( + host=namespace.host, + client_id=client_id, + client_secret=client_secret, + azure_client_id=namespace.azure_client_id, + azure_client_secret=namespace.azure_client_secret, + auth_type="external-browser", + ) + me = w.current_user.me() + print(me) + + if namespace.register_custom_app: + delete_custom_app(client_id) + + diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index 4128de5ca..7c18eadc7 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -31,20 +31,21 @@ import logging import sys -from databricks.sdk.oauth import OAuthClient +from databricks.sdk.oauth import OAuthClient, get_workspace_endpoints +from databricks.sdk.service.compute import ListClustersFilterBy, State APP_NAME = "flask-demo" all_clusters_template = """""" -def create_flask_app(oauth_client: OAuthClient): +def create_flask_app(workspace_host: str, client_id: str, client_secret: str): """The create_flask_app function creates a Flask app that is enabled with OAuth. It initializes the app and web session secret keys with a randomly generated token. It defines two routes for @@ -64,7 +65,7 @@ def callback(): the callback parameters, and redirects the user to the index page.""" from databricks.sdk.oauth import Consent - consent = Consent.from_dict(oauth_client, session["consent"]) + consent = Consent.from_dict(session["consent"], client_secret=client_secret) session["creds"] = consent.exchange_callback_parameters(request.args).as_dict() return redirect(url_for("index")) @@ -72,21 +73,34 @@ def callback(): def index(): """The index page checks if the user has already authenticated and retrieves the user's credentials using the Databricks SDK WorkspaceClient. It then renders the template with the clusters' list.""" + oidc_endpoints = get_workspace_endpoints(workspace_host) + port = request.environ.get("SERVER_PORT") + redirect_url=f"http://localhost:{port}/callback" if "creds" not in session: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) consent = oauth_client.initiate_consent() session["consent"] = consent.as_dict() - return redirect(consent.auth_url) + return redirect(consent.authorization_url) from databricks.sdk import WorkspaceClient from databricks.sdk.oauth import SessionCredentials - credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"]) - workspace_client = WorkspaceClient(host=oauth_client.host, + credentials_strategy = SessionCredentials.from_dict(session["creds"], + token_endpoint=oidc_endpoints.token_endpoint, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) + workspace_client = WorkspaceClient(host=workspace_host, product=APP_NAME, - credentials_provider=credentials_provider, + credentials_strategy=credentials_strategy, ) - - return render_template_string(all_clusters_template, w=workspace_client) + clusters = workspace_client.clusters.list( + filter_by=ListClustersFilterBy(cluster_states=[State.RUNNING, State.PENDING]) + ) + return render_template_string(all_clusters_template, workspace_host=workspace_host, clusters=clusters) return app @@ -100,7 +114,11 @@ def register_custom_app(args: argparse.Namespace) -> tuple[str, str]: account_client = AccountClient(profile=args.profile) custom_app = account_client.custom_app_integration.create( - name=APP_NAME, redirect_urls=[f"http://localhost:{args.port}/callback"], confidential=True, + name=APP_NAME, + redirect_urls=[ + f"http://localhost:{args.port}/callback", + ], + confidential=True, scopes=["all-apis"], ) logging.info(f"Created new custom app: " @@ -110,22 +128,6 @@ def register_custom_app(args: argparse.Namespace) -> tuple[str, str]: return custom_app.client_id, custom_app.client_secret -def init_oauth_config(args) -> OAuthClient: - """Creates Databricks SDK configuration for OAuth""" - oauth_client = OAuthClient(host=args.host, - client_id=args.client_id, - client_secret=args.client_secret, - redirect_url=f"http://localhost:{args.port}/callback", - scopes=["all-apis"], - ) - if not oauth_client.client_id: - client_id, client_secret = register_custom_app(args) - oauth_client.client_id = client_id - oauth_client.client_secret = client_secret - - return oauth_client - - def parse_arguments() -> argparse.Namespace: """Parses arguments for this demo""" parser = argparse.ArgumentParser(prog=APP_NAME, description=__doc__.strip()) @@ -145,8 +147,10 @@ def parse_arguments() -> argparse.Namespace: logging.getLogger("databricks.sdk").setLevel(logging.DEBUG) args = parse_arguments() - oauth_cfg = init_oauth_config(args) - app = create_flask_app(oauth_cfg) + client_id, client_secret = args.client_id, args.client_secret + if not client_id: + client_id, client_secret = register_custom_app(args) + app = create_flask_app(args.host, client_id, client_secret) app.run( host="localhost", diff --git a/tests/test_oauth.py b/tests/test_oauth.py index ce2d514ff..a637a5508 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,29 +1,126 @@ -from databricks.sdk.core import Config -from databricks.sdk.oauth import OAuthClient, OidcEndpoints, TokenCache - - -def test_token_cache_unique_filename_by_host(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(host="http://localhost:", **common_args) - c2 = OAuthClient(host="https://bar.cloud.databricks.com", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_client_id(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", redirect_url="http://localhost:8020") - c1 = OAuthClient(client_id="abc", **common_args) - c2 = OAuthClient(client_id="def", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_scopes(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(scopes=["foo"], **common_args) - c2 = OAuthClient(scopes=["bar"], **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename +from databricks.sdk._base_client import _BaseClient +from databricks.sdk.oauth import (OidcEndpoints, TokenCache, + get_account_endpoints, + get_workspace_endpoints) + +from .clock import FakeClock + + +def test_token_cache_unique_filename_by_host(): + common_args = dict(client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(host="http://localhost:", + **common_args).filename != TokenCache("https://bar.cloud.databricks.com", + **common_args).filename + + +def test_token_cache_unique_filename_by_client_id(): + common_args = dict(host="http://localhost:", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", + **common_args).filename + + +def test_token_cache_unique_filename_by_scopes(): + common_args = dict(host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], + **common_args).filename + + +def test_account_oidc_endpoints(requests_mock): + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + + +def test_account_oidc_endpoints_retry_on_429(requests_mock): + # It doesn't seem possible to use requests_mock to return different responses for the same request, e.g. when + # simulating a transient failure. Instead, the nth_request matcher increments a test-wide counter and only matches + # the nth request. + request_count = 0 + + def nth_request(n): + + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + + return observe_request + + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={ + "authorization_endpoint": + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + + +def test_workspace_oidc_endpoints(requests_mock): + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints("https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") + + +def test_workspace_oidc_endpoints_retry_on_429(requests_mock): + request_count = 0 + + def nth_request(n): + + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + + return observe_request + + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={ + "authorization_endpoint": + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints("https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token")