Skip to content

Commit

Permalink
Make u2m authentication work with new CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
pietern committed Jun 7, 2023
1 parent 97212ef commit d1ae70e
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 18 deletions.
61 changes: 44 additions & 17 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,37 +273,64 @@ def inner() -> Dict[str, str]:
return inner


class BricksCliTokenSource(CliTokenSource):
""" Obtain the token granted by `bricks auth login` CLI command """
class DatabricksCliTokenSource(CliTokenSource):
""" Obtain the token granted by `databricks auth login` CLI command """

def __init__(self, cfg: 'Config'):
cli_path = cfg.bricks_cli_path
if not cli_path:
cli_path = 'bricks'
cmd = [cli_path, 'auth', 'token', '--host', cfg.host]
args = ['auth', 'token', '--host', cfg.host]
if cfg.is_account_client:
cmd += ['--account-id', cfg.account_id]
super().__init__(cmd=cmd,
args += ['--account-id', cfg.account_id]

cli_path = cfg.databricks_cli_path
if not cli_path:
cli_path = 'databricks'

# If the path is unqualified, look it up in PATH.
if cli_path.count("/") == 0:
cli_path = self.__class__._find_executable(cli_path)

super().__init__(cmd=[cli_path, *args],
token_type_field='token_type',
access_token_field='access_token',
expiry_field='expiry')

@staticmethod
def _find_executable(name) -> str:
err = FileNotFoundError("Most likely the Databricks CLI is not installed")
for dir in os.getenv("PATH", default="").split(os.path.pathsep):
path = pathlib.Path(dir).joinpath(name).resolve()
if not path.is_file():
continue

# The new Databricks CLI is a single binary with size > 1MB.
# We use the size as a signal to determine which Databricks CLI is installed.
stat = path.stat()
if stat.st_size < (1024 * 1024):
err = FileNotFoundError("Databricks CLI version <0.100.0 detected")
continue

return str(path)

raise err

@credentials_provider('bricks-cli', ['host', 'is_aws'])
def bricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
token_source = BricksCliTokenSource(cfg)

@credentials_provider('databricks-cli', ['host', 'is_aws'])
def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
try:
token_source.token()
except FileNotFoundError:
logger.debug(f'Most likely Bricks CLI is not installed.')
token_source = DatabricksCliTokenSource(cfg)
except FileNotFoundError as e:
logger.debug(e)
return None

try:
token_source.token()
except IOError as e:
if 'databricks OAuth is not' in str(e):
logger.debug(f'OAuth not configured or not available: {e}')
return None
raise e

logger.info("Using Bricks CLI authentication")
logger.info("Using Databricks CLI authentication")

def inner() -> Dict[str, str]:
token = token_source.token()
Expand Down Expand Up @@ -375,7 +402,7 @@ def auth_type(self) -> str:
def __call__(self, cfg: 'Config') -> HeaderFactory:
auth_providers = [
pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal,
azure_cli, external_browser, bricks_cli, runtime_native_auth
azure_cli, external_browser, databricks_cli, runtime_native_auth
]
for provider in auth_providers:
auth_type = provider.auth_type()
Expand Down Expand Up @@ -438,7 +465,7 @@ class Config:
azure_tenant_id = ConfigAttribute(env='ARM_TENANT_ID', auth='azure')
azure_environment = ConfigAttribute(env='ARM_ENVIRONMENT')
azure_login_app_id = ConfigAttribute(env='DATABRICKS_AZURE_LOGIN_APP_ID', auth='azure')
bricks_cli_path = ConfigAttribute(env='BRICKS_CLI_PATH')
databricks_cli_path = ConfigAttribute(env='DATABRICKS_CLI_PATH')
auth_type = ConfigAttribute(env='DATABRICKS_AUTH_TYPE')
cluster_id = ConfigAttribute(env='DATABRICKS_CLUSTER_ID')
warehouse_id = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID')
Expand Down
116 changes: 115 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from databricks.sdk.core import Config
import os
import pathlib
import pytest
import random
import string

from databricks.sdk.core import Config, DatabricksCliTokenSource, databricks_cli


def test_parse_dsn():
Expand All @@ -8,3 +14,111 @@ def test_parse_dsn():

assert headers['Authorization'] == 'Basic dXNlcjpwYXNz'
assert 'basic' == cfg.auth_type


def test_databricks_cli_token_source_relative_path(config):
config.databricks_cli_path = "./relative/path/to/cli"
ts = DatabricksCliTokenSource(config)
assert ts._cmd[0] == config.databricks_cli_path


def test_databricks_cli_token_source_absolute_path(config):
config.databricks_cli_path = "/absolute/path/to/cli"
ts = DatabricksCliTokenSource(config)
assert ts._cmd[0] == config.databricks_cli_path


def test_databricks_cli_token_source_not_installed(config, monkeypatch):
monkeypatch.setenv('PATH', 'whatever')
with pytest.raises(FileNotFoundError, match="not installed"):
ts = DatabricksCliTokenSource(config)


def write_small_dummy_executable(path: pathlib.Path):
cli = path.joinpath('databricks')
cli.write_text('#!/bin/sh\necho "hello world"\n')
cli.chmod(0o755)
assert cli.stat().st_size < 1024
return cli


def write_large_dummy_executable(path: pathlib.Path):
cli = path.joinpath('databricks')

# Generate a long random string to inflate the file size.
random_string = ''.join(random.choice(string.ascii_letters) for i in range(1024*1024))
cli.write_text(
"""#!/bin/sh
cat <<EOF
{
"access_token": "token",
"token_type": "Bearer",
"expiry": "2023-05-22T00:00:00.000000+00:00"
}
EOF
exit 0
""" + random_string)
cli.chmod(0o755)
assert cli.stat().st_size >= (1024*1024)
return cli


def test_databricks_cli_token_source_installed_legacy(config, monkeypatch, tmp_path):
write_small_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())
with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"):
ts = DatabricksCliTokenSource(config)


def test_databricks_cli_token_source_installed_legacy_with_symlink(config, monkeypatch, tmp_path):
dir1 = tmp_path.joinpath('dir1')
dir2 = tmp_path.joinpath('dir2')
dir1.mkdir()
dir2.mkdir()

(dir1 / "databricks").symlink_to(write_small_dummy_executable(dir2))

monkeypatch.setenv('PATH', dir1.as_posix())
with pytest.raises(FileNotFoundError, match="version <0.100.0 detected"):
ts = DatabricksCliTokenSource(config)


def test_databricks_cli_token_source_installed_new(config, monkeypatch, tmp_path):
write_large_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())
ts = DatabricksCliTokenSource(config)


def test_databricks_cli_token_source_installed_both(config, monkeypatch, tmp_path):
dir1 = tmp_path.joinpath('dir1')
dir2 = tmp_path.joinpath('dir2')
dir1.mkdir()
dir2.mkdir()

write_small_dummy_executable(dir1)
write_large_dummy_executable(dir2)

# Resolve small before large.
monkeypatch.setenv('PATH', str(os.pathsep).join([dir1.as_posix(), dir2.as_posix()]))
ts = DatabricksCliTokenSource(config)

# Resolve large before small.
monkeypatch.setenv('PATH', str(os.pathsep).join([dir2.as_posix(), dir1.as_posix()]))
ts = DatabricksCliTokenSource(config)


def test_databricks_cli_credential_provider_not_installed(config, monkeypatch):
monkeypatch.setenv('PATH', 'whatever')
assert databricks_cli(config) == None


def test_databricks_cli_credential_provider_installed_legacy(config, monkeypatch, tmp_path):
write_small_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())
assert databricks_cli(config) == None


def test_databricks_cli_credential_provider_installed_new(config, monkeypatch, tmp_path):
write_large_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', str(os.pathsep).join([tmp_path.as_posix(), os.environ['PATH']]))
assert databricks_cli(config) is not None

0 comments on commit d1ae70e

Please sign in to comment.