Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added w.config.account_host to get the relevant account host from a workspace client #390

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 80 additions & 13 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import subprocess
import sys
import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from json import JSONDecodeError
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Expand Down Expand Up @@ -577,6 +579,59 @@ def __repr__(self) -> str:
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"


class Cloud(Enum):
UNSPECIFIED = None

AWS = 'aws'
AZURE = 'azure'
GCP = 'gcp'


@dataclass
class DatabricksEnvironment:
cloud: Cloud

# zones are not very secret: https://crt.sh/?q=databricks
dns_zone: str

azure_environment: Optional[AzureEnvironment] = None

# The application (client) ID isn't a secret. See
# https://learn.microsoft.com/en-us/entra/identity-platform/developer-glossary#application-client-id
azure_application_id: str = None

def deployment(self, deployment_name: str) -> str:
return f'https://{deployment_name}{self.dns_zone}'


_DATABRICKS_ENVIRONMENTS = [
DatabricksEnvironment(Cloud.UNSPECIFIED, 'localhost'),
DatabricksEnvironment(Cloud.AWS, '.dev.databricks.com'),
DatabricksEnvironment(Cloud.AWS, '.staging.cloud.databricks.com'),
DatabricksEnvironment(Cloud.AWS, '.cloud.databricks.com'),
DatabricksEnvironment(Cloud.AWS, '.cloud.databricks.us'),
DatabricksEnvironment(Cloud.AZURE,
'.dev.azuredatabricks.net',
azure_environment=ENVIRONMENTS['PUBLIC'],
azure_application_id='62a912ac-b58e-4c1d-89ea-b2dbfc7358fc'),
DatabricksEnvironment(Cloud.AZURE,
'.staging.azuredatabricks.net',
azure_environment=ENVIRONMENTS['PUBLIC'],
azure_application_id='4a67d088-db5c-48f1-9ff2-0aace800ae68'),
DatabricksEnvironment(Cloud.AZURE,
'.azuredatabricks.net',
azure_environment=ENVIRONMENTS['PUBLIC'],
azure_application_id=ARM_DATABRICKS_RESOURCE_ID),
DatabricksEnvironment(Cloud.AZURE,
'.databricks.azure.us',
azure_environment=ENVIRONMENTS['USGOVERNMENT'],
azure_application_id=ARM_DATABRICKS_RESOURCE_ID),
DatabricksEnvironment(Cloud.GCP, '.dev.gcp.databricks.com'),
DatabricksEnvironment(Cloud.GCP, '.staging.gcp.databricks.com'),
DatabricksEnvironment(Cloud.GCP, '.gcp.databricks.com'),
]


class Config:
host = ConfigAttribute(env='DATABRICKS_HOST')
account_id = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID')
Expand Down Expand Up @@ -667,16 +722,12 @@ def as_dict(self) -> dict:
@property
def is_azure(self) -> bool:
has_resource_id = self.azure_workspace_resource_id is not None
has_host = self.host is not None
is_public_cloud = has_host and ".azuredatabricks.net" in self.host
is_china_cloud = has_host and ".databricks.azure.cn" in self.host
is_gov_cloud = has_host and ".databricks.azure.us" in self.host
is_valid_cloud = is_public_cloud or is_china_cloud or is_gov_cloud
return has_resource_id or (has_host and is_valid_cloud)
azure_environment = self.environment.azure_environment is not None
return has_resource_id or azure_environment

@property
def is_gcp(self) -> bool:
return self.host and ".gcp.databricks.com" in self.host
return self.host and self.environment.cloud == Cloud.GCP

@property
def is_aws(self) -> bool:
Expand All @@ -688,20 +739,36 @@ def is_account_client(self) -> bool:
return False
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")

@property
def account_host(self) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work for dev/staging, right? We occasionally get internal issues where the SDK doesn't work well with dev/staging.

Also, can we label this somehow as experimental? In the future, if we introduce a workspace-level metadata service, the accounts endpoint can be exposed there, and we can remove this property.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you suggest on the pattern here to make it work for dev/staging?

btw, once we introduce workspace-level metadata service, we should just print warnings from within this call.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea: have a list to lookup cloud/env/account host name based on domain suffix, from most-specific to least-specific:

    # Most specific to least specific
    [
        ('.dev.databricks.com', 'AWS', 'DEV', 'PUBLIC', 'accounts.dev.databricks.com'),
        ('.staging.cloud.databricks.com', 'AWS', 'STAGING', 'PUBLIC', 'accounts.staging.cloud.databricks.com'),
        ('.cloud.databricks.com', 'AWS', 'PROD', 'PUBLIC', 'accounts.cloud.databricks.com'),
        ('.cloud.databricks.us', 'AWS', 'PROD', 'GOV', 'accounts.cloud.databricks.us'),
        ('.dev.azuredatabricks.net', 'AZURE', 'DEV', 'PUBLIC', 'accounts.dev.azuredatabricks.net')
        ('.staging.azuredatabricks.net', 'AZURE', 'STAGING', 'PUBLIC', 'accounts.staging.azuredatabricks.net'),
        ('.azuredatabricks.net', 'AZURE', 'PROD', 'PUBLIC', 'accounts.azuredatabricks.net'),
        ('.databricks.azure.us', 'AZURE', 'PROD', 'GOV', 'accounts.databricks.azure.us'),
        ('.dev.gcp.databricks.com', 'GCP', 'DEV', 'PUBLIC', 'accounts.dev.gcp.databricks.com'),
        ('.staging.gcp.databricks.com', 'GCP', 'STAGING', 'PUBLIC', 'accounts.staging.gcp.databricks.com'),
        ('.gcp.databricks.com', 'GCP', 'PROD', 'PUBLIC', 'accounts.gcp.databricks.com'),
    ]

The simpler version is to just replace the most specific zone with accounts, which works everywhere today.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgyucht do we want to expose dev/staging?... if we do - then why don't we expose the azure_login_app_id along with it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's definitely fine by me!

if self.is_account_client:
return self.host
return self.environment.deployment('accounts')

@property
def arm_environment(self) -> AzureEnvironment:
env = self.azure_environment if self.azure_environment else "PUBLIC"
try:
return ENVIRONMENTS[env]
except KeyError:
raise ValueError(f"Cannot find Azure {env} Environment")
return self.environment.azure_environment

@property
def effective_azure_login_app_id(self):
app_id = self.azure_login_app_id
if app_id:
return app_id
return ARM_DATABRICKS_RESOURCE_ID
return self.environment.azure_application_id

@property
def environment(self) -> DatabricksEnvironment:
hostname = self.hostname
if ':' in hostname:
# special case for unit tests and OAuth U2M
hostname, _ = hostname.split(':')
if hostname == 'x' or hostname == '127.0.0.1':
# special case for unit tests
hostname = 'localhost'
for env in _DATABRICKS_ENVIRONMENTS:
if hostname.endswith(env.dns_zone):
return env
raise ValueError(f"Cannot find DatabricksEnvironment for {hostname}")

@property
def hostname(self) -> str:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,18 @@ def inner(h: BaseHTTPRequestHandler):
assert len(requests) == 2


@pytest.mark.parametrize(
"host,account_host",
[('https://accounts.cloud.databricks.com', 'https://accounts.cloud.databricks.com'),
('https://dbc-ldflSlsd.cloud.databricks.com', 'https://accounts.cloud.databricks.com'),
('https://abd-23424234234.12.azuredatabricks.net', 'https://accounts.azuredatabricks.net'),
('https://abd-23424234234.12.databricks.azure.us', 'https://accounts.databricks.azure.us'),
('https://23423423.gcp.databricks.com', 'https://accounts.gcp.databricks.com'), ])
def test_get_account_host(host, account_host):
cfg = Config(host=host, token=...)
assert account_host == cfg.account_host


def test_github_oidc_flow_works_with_azure(monkeypatch):

def inner(h: BaseHTTPRequestHandler):
Expand Down
Loading