Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make u2m authentication work with new CLI #150

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 44 additions & 17 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,37 +273,64 @@ def inner() -> Dict[str, str]:
return inner


class BricksCliTokenSource(CliTokenSource):
""" Obtain the token granted by `bricks auth login` CLI command """
class DatabricksCliTokenSource(CliTokenSource):
""" Obtain the token granted by `databricks auth login` CLI command """

def __init__(self, cfg: 'Config'):
cli_path = cfg.bricks_cli_path
if not cli_path:
cli_path = 'bricks'
cmd = [cli_path, 'auth', 'token', '--host', cfg.host]
args = ['auth', 'token', '--host', cfg.host]
if cfg.is_account_client:
cmd += ['--account-id', cfg.account_id]
super().__init__(cmd=cmd,
args += ['--account-id', cfg.account_id]

cli_path = cfg.databricks_cli_path
if not cli_path:
cli_path = 'databricks'

# If the path is unqualified, look it up in PATH.
if cli_path.count("/") == 0:
cli_path = self.__class__._find_executable(cli_path)

super().__init__(cmd=[cli_path, *args],
token_type_field='token_type',
access_token_field='access_token',
expiry_field='expiry')

@staticmethod
def _find_executable(name) -> str:
err = FileNotFoundError("Most likely the Databricks CLI is not installed")
for dir in os.getenv("PATH", default="").split(os.path.pathsep):
path = pathlib.Path(dir).joinpath(name).resolve()
if not path.is_file():
continue

# The new Databricks CLI is a single binary with size > 1MB.
# We use the size as a signal to determine which Databricks CLI is installed.
stat = path.stat()
if stat.st_size < (1024 * 1024):
err = FileNotFoundError("Databricks CLI version <0.100.0 detected")
continue

return str(path)

raise err

@credentials_provider('bricks-cli', ['host', 'is_aws'])
def bricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
token_source = BricksCliTokenSource(cfg)

@credentials_provider('databricks-cli', ['host', 'is_aws'])
def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
try:
token_source.token()
except FileNotFoundError:
logger.debug(f'Most likely Bricks CLI is not installed.')
token_source = DatabricksCliTokenSource(cfg)
except FileNotFoundError as e:
logger.debug(e)
return None

try:
token_source.token()
except IOError as e:
if 'databricks OAuth is not' in str(e):
logger.debug(f'OAuth not configured or not available: {e}')
return None
raise e

logger.info("Using Bricks CLI authentication")
logger.info("Using Databricks CLI authentication")

def inner() -> Dict[str, str]:
token = token_source.token()
Expand Down Expand Up @@ -375,7 +402,7 @@ def auth_type(self) -> str:
def __call__(self, cfg: 'Config') -> HeaderFactory:
auth_providers = [
pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
azure_cli, external_browser, bricks_cli, runtime_native_auth
azure_cli, external_browser, databricks_cli, runtime_native_auth
]
for provider in auth_providers:
auth_type = provider.auth_type()
Expand Down Expand Up @@ -438,7 +465,7 @@ class Config:
azure_tenant_id = ConfigAttribute(env='ARM_TENANT_ID', auth='azure')
azure_environment = ConfigAttribute(env='ARM_ENVIRONMENT')
azure_login_app_id = ConfigAttribute(env='DATABRICKS_AZURE_LOGIN_APP_ID', auth='azure')
bricks_cli_path = ConfigAttribute(env='BRICKS_CLI_PATH')
databricks_cli_path = ConfigAttribute(env='DATABRICKS_CLI_PATH')
auth_type = ConfigAttribute(env='DATABRICKS_AUTH_TYPE')
cluster_id = ConfigAttribute(env='DATABRICKS_CLUSTER_ID')
warehouse_id = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID')
Expand Down
117 changes: 116 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from databricks.sdk.core import Config
import os
import pathlib
import random
import string

import pytest

from databricks.sdk.core import (Config, DatabricksCliTokenSource,
databricks_cli)


def test_parse_dsn():
Expand All @@ -8,3 +16,110 @@ def test_parse_dsn():

assert headers['Authorization'] == 'Basic dXNlcjpwYXNz'
assert 'basic' == cfg.auth_type


def test_databricks_cli_token_source_relative_path(config):
config.databricks_cli_path = "./relative/path/to/cli"
ts = DatabricksCliTokenSource(config)
assert ts._cmd[0] == config.databricks_cli_path


def test_databricks_cli_token_source_absolute_path(config):
config.databricks_cli_path = "/absolute/path/to/cli"
ts = DatabricksCliTokenSource(config)
assert ts._cmd[0] == config.databricks_cli_path


def test_databricks_cli_token_source_not_installed(config, monkeypatch):
monkeypatch.setenv('PATH', 'whatever')
with pytest.raises(FileNotFoundError, match="not installed"):
DatabricksCliTokenSource(config)


def write_small_dummy_executable(path: pathlib.Path):
cli = path.joinpath('databricks')
cli.write_text('#!/bin/sh\necho "hello world"\n')
cli.chmod(0o755)
assert cli.stat().st_size < 1024
return cli


def write_large_dummy_executable(path: pathlib.Path):
cli = path.joinpath('databricks')

# Generate a long random string to inflate the file size.
random_string = ''.join(random.choice(string.ascii_letters) for i in range(1024 * 1024))
cli.write_text("""#!/bin/sh
cat <<EOF
{
"access_token": "token",
"token_type": "Bearer",
"expiry": "2023-05-22T00:00:00.000000+00:00"
}
EOF
exit 0
""" + random_string)
cli.chmod(0o755)
assert cli.stat().st_size >= (1024 * 1024)
return cli


def test_databricks_cli_token_source_installed_legacy(config, monkeypatch, tmp_path):
write_small_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())
with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"):
DatabricksCliTokenSource(config)


def test_databricks_cli_token_source_installed_legacy_with_symlink(config, monkeypatch, tmp_path):
dir1 = tmp_path.joinpath('dir1')
dir2 = tmp_path.joinpath('dir2')
dir1.mkdir()
dir2.mkdir()

(dir1 / "databricks").symlink_to(write_small_dummy_executable(dir2))

monkeypatch.setenv('PATH', dir1.as_posix())
with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"):
DatabricksCliTokenSource(config)


def test_databricks_cli_token_source_installed_new(config, monkeypatch, tmp_path):
write_large_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())
DatabricksCliTokenSource(config)


def test_databricks_cli_token_source_installed_both(config, monkeypatch, tmp_path):
dir1 = tmp_path.joinpath('dir1')
dir2 = tmp_path.joinpath('dir2')
dir1.mkdir()
dir2.mkdir()

write_small_dummy_executable(dir1)
write_large_dummy_executable(dir2)

# Resolve small before large.
monkeypatch.setenv('PATH', str(os.pathsep).join([dir1.as_posix(), dir2.as_posix()]))
DatabricksCliTokenSource(config)

# Resolve large before small.
monkeypatch.setenv('PATH', str(os.pathsep).join([dir2.as_posix(), dir1.as_posix()]))
DatabricksCliTokenSource(config)


def test_databricks_cli_credential_provider_not_installed(config, monkeypatch):
monkeypatch.setenv('PATH', 'whatever')
assert databricks_cli(config) == None


def test_databricks_cli_credential_provider_installed_legacy(config, monkeypatch, tmp_path):
write_small_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())
assert databricks_cli(config) == None


def test_databricks_cli_credential_provider_installed_new(config, monkeypatch, tmp_path):
write_large_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', str(os.pathsep).join([tmp_path.as_posix(), os.environ['PATH']]))
assert databricks_cli(config) is not None