Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error message when private link enabled workspaces reject requests #647

Merged
merged 5 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions databricks/sdk/azure.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,9 @@
from dataclasses import dataclass
from typing import Dict

from .oauth import TokenSource
from .service.provisioning import Workspace


@dataclass
class AzureEnvironment:
name: str
service_management_endpoint: str
resource_manager_endpoint: str
active_directory_endpoint: str


ARM_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

ENVIRONMENTS = dict(
PUBLIC=AzureEnvironment(name="PUBLIC",
service_management_endpoint="https://management.core.windows.net/",
resource_manager_endpoint="https://management.azure.com/",
active_directory_endpoint="https://login.microsoftonline.com/"),
USGOVERNMENT=AzureEnvironment(name="USGOVERNMENT",
service_management_endpoint="https://management.core.usgovcloudapi.net/",
resource_manager_endpoint="https://management.usgovcloudapi.net/",
active_directory_endpoint="https://login.microsoftonline.us/"),
CHINA=AzureEnvironment(name="CHINA",
service_management_endpoint="https://management.core.chinacloudapi.cn/",
resource_manager_endpoint="https://management.chinacloudapi.cn/",
active_directory_endpoint="https://login.chinacloudapi.cn/"),
)


def add_workspace_id_header(cfg: 'Config', headers: Dict[str, str]):
    """Add the Azure workspace resource ID header to ``headers`` when one is configured.

    :param cfg: configuration object exposing ``azure_workspace_resource_id``.
    :param headers: header mapping that is updated in place.

    Nothing is added when ``cfg.azure_workspace_resource_id`` is falsy.
    """
    resource_id = cfg.azure_workspace_resource_id
    if not resource_id:
        return
    headers["X-Databricks-Azure-Workspace-Resource-Id"] = resource_id
Expand Down
15 changes: 6 additions & 9 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

import requests

from .azure import AzureEnvironment
from .clock import Clock, RealClock
from .credentials_provider import CredentialsProvider, DefaultCredentials
from .environments import (ALL_ENVS, DEFAULT_ENVIRONMENT, Cloud,
DatabricksEnvironment)
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import OidcEndpoints
from .version import __version__

Expand Down Expand Up @@ -154,11 +153,7 @@ def environment(self) -> DatabricksEnvironment:
"""Returns the environment based on configuration."""
if self.databricks_environment:
return self.databricks_environment
if self.host:
for environment in ALL_ENVS:
if self.host.endswith(environment.dns_zone):
return environment
if self.azure_workspace_resource_id:
if not self.host and self.azure_workspace_resource_id:
azure_env = self._get_azure_environment_name()
for environment in ALL_ENVS:
if environment.cloud != Cloud.AZURE:
Expand All @@ -168,10 +163,12 @@ def environment(self) -> DatabricksEnvironment:
if environment.dns_zone.startswith(".dev") or environment.dns_zone.startswith(".staging"):
continue
return environment
return DEFAULT_ENVIRONMENT
return get_environment_for_hostname(self.host)

@property
def is_azure(self) -> bool:
    """Whether this configuration targets Azure.

    An explicitly configured Azure workspace resource ID is taken as proof on its own;
    otherwise the answer comes from the cloud of the resolved environment.
    """
    if not self.azure_workspace_resource_id:
        return self.environment.cloud == Cloud.AZURE
    return True

@property
Expand Down
5 changes: 5 additions & 0 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
from .errors import DatabricksError, error_mapper
from .errors.private_link import _is_private_link_redirect
from .retries import retried

__all__ = ['Config', 'DatabricksError']
Expand Down Expand Up @@ -239,6 +240,10 @@ def _perform(self,
# See https://stackoverflow.com/a/58821552/277035
payload = response.json()
raise self._make_nicer_error(response=response, **payload) from None
# Private link failures happen via a redirect to the login page. From a requests-perspective, the request
# is successful, but the response is not what we expect. We need to handle this case separately.
if _is_private_link_redirect(response):
raise self._make_nicer_error(response=response) from None
return response
except requests.exceptions.JSONDecodeError:
message = self._make_sense_from_html(response.text)
Expand Down
35 changes: 34 additions & 1 deletion databricks/sdk/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,31 @@
from enum import Enum
from typing import Optional

from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment

@dataclass
class AzureEnvironment:
    """Named group of endpoint URLs for one Azure cloud."""

    # Key under which the environment is registered (e.g. "PUBLIC").
    name: str
    # Base URL of the service management API.
    service_management_endpoint: str
    # Base URL of the resource manager API.
    resource_manager_endpoint: str
    # Base URL of the Active Directory login endpoint.
    active_directory_endpoint: str


# NOTE(review): presumably the Azure AD application ID of the Azure Databricks resource — confirm.
ARM_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

# Azure clouds keyed by name; each key matches the ``name`` of its entry.
ENVIRONMENTS = {
    "PUBLIC": AzureEnvironment(
        name="PUBLIC",
        service_management_endpoint="https://management.core.windows.net/",
        resource_manager_endpoint="https://management.azure.com/",
        active_directory_endpoint="https://login.microsoftonline.com/",
    ),
    "USGOVERNMENT": AzureEnvironment(
        name="USGOVERNMENT",
        service_management_endpoint="https://management.core.usgovcloudapi.net/",
        resource_manager_endpoint="https://management.usgovcloudapi.net/",
        active_directory_endpoint="https://login.microsoftonline.us/",
    ),
    "CHINA": AzureEnvironment(
        name="CHINA",
        service_management_endpoint="https://management.core.chinacloudapi.cn/",
        resource_manager_endpoint="https://management.chinacloudapi.cn/",
        active_directory_endpoint="https://login.chinacloudapi.cn/",
    ),
}


class Cloud(Enum):
Expand Down Expand Up @@ -70,3 +94,12 @@ def azure_active_directory_endpoint(self) -> Optional[str]:
DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"),
DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com")
]


def get_environment_for_hostname(hostname: str) -> DatabricksEnvironment:
    """Resolve the Databricks environment whose DNS zone is a suffix of ``hostname``.

    :param hostname: host to classify; a falsy value (``None`` or ``""``) is accepted.
    :return: the first entry of ``ALL_ENVS`` whose ``dns_zone`` the hostname ends with,
        or ``DEFAULT_ENVIRONMENT`` when the hostname is falsy or nothing matches.
    """
    if hostname:
        # ALL_ENVS is scanned in declaration order, so the first matching zone wins.
        candidates = (candidate for candidate in ALL_ENVS if hostname.endswith(candidate.dns_zone))
        return next(candidates, DEFAULT_ENVIRONMENT)
    return DEFAULT_ENVIRONMENT
1 change: 1 addition & 0 deletions databricks/sdk/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base import DatabricksError, ErrorDetail
from .mapper import error_mapper
from .platform import *
from .private_link import PrivateLinkValidationError
from .sdk import *
4 changes: 4 additions & 0 deletions databricks/sdk/errors/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from databricks.sdk.errors.base import DatabricksError

from .overrides import _ALL_OVERRIDES
from .private_link import (_get_private_link_validation_error,
_is_private_link_redirect)


def error_mapper(response: requests.Response, raw: dict) -> DatabricksError:
Expand All @@ -21,6 +23,8 @@ def error_mapper(response: requests.Response, raw: dict) -> DatabricksError:
# where there's a default exception class per HTTP status code, and we do
# rely on Databricks platform exception mapper to do the right thing.
return platform.STATUS_CODE_MAPPING[status_code](**raw)
if _is_private_link_redirect(response):
return _get_private_link_validation_error(response.url)

# backwards-compatible error creation for cases like using older versions of
# the SDK on way newer releases of the platform.
Expand Down
60 changes: 60 additions & 0 deletions databricks/sdk/errors/private_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from dataclasses import dataclass
from urllib import parse

import requests

from ..environments import Cloud, get_environment_for_hostname
from .platform import PermissionDenied


@dataclass
class _PrivateLinkInfo:
    """Cloud-specific wording used to explain a private-link rejection to the user."""

    # Product name of the private connectivity feature.
    serviceName: str
    # Name of the endpoint the user's device must be able to reach.
    endpointName: str
    # Documentation URL quoted at the end of the message.
    referencePage: str

    def error_message(self):
        """Return the user-facing explanation built from the three fields."""
        service = self.serviceName
        endpoint = self.endpointName
        page = self.referencePage
        parts = (
            f'The requested workspace has {service} enabled and is not accessible from the current network. ',
            f'Ensure that {service} is properly configured and that your device has access to the ',
            f'{endpoint}. For more information, see {page}.',
        )
        return ''.join(parts)


# Message wording per cloud; looked up by the cloud of the environment that served the redirect.
_private_link_info_map = {
    # AWS
    Cloud.AWS: _PrivateLinkInfo(
        serviceName='AWS PrivateLink',
        endpointName='AWS VPC endpoint',
        referencePage='https://docs.databricks.com/en/security/network/classic/privatelink.html',
    ),
    # Azure
    Cloud.AZURE: _PrivateLinkInfo(
        serviceName='Azure Private Link',
        endpointName='Azure Private Link endpoint',
        referencePage='https://learn.microsoft.com/en-us/azure/databricks/security/network/classic/private-link-standard#authentication-troubleshooting',
    ),
    # GCP
    Cloud.GCP: _PrivateLinkInfo(
        serviceName='Private Service Connect',
        endpointName='GCP VPC endpoint',
        referencePage='https://docs.gcp.databricks.com/en/security/network/classic/private-service-connect.html',
    ),
}


class PrivateLinkValidationError(PermissionDenied):
    """Permission error for a Private Link-enabled workspace.

    Raised when the workspace is accessed from a network that does not have access to it.
    """


def _is_private_link_redirect(resp: requests.Response) -> bool:
    """Tell whether ``resp`` ended up on the private-link validation error login page.

    :param resp: response whose final ``url`` is inspected.
    :return: ``True`` when the URL path is ``/login.html`` and its query string contains
        ``error=private-link-validation-error``.
    """
    landed_on = parse.urlparse(resp.url)
    if landed_on.path != '/login.html':
        return False
    return 'error=private-link-validation-error' in landed_on.query


def _get_private_link_validation_error(url: str) -> PrivateLinkValidationError:
    """Build the error describing a private-link rejection for the workspace at ``url``.

    The hostname of ``url`` selects the Databricks environment, and that environment's
    cloud selects the wording of the message.

    :param url: URL the request ended up on (the login page redirect).
    :return: a ``PrivateLinkValidationError`` with error code
        ``PRIVATE_LINK_VALIDATION_ERROR`` and status code 403. It is returned, not raised.
    """
    parsed = parse.urlparse(url)
    # ``parsed.hostname`` is None for URLs without a network location; the lookup
    # below then falls back to the default environment.
    env = get_environment_for_hostname(parsed.hostname)
    return PrivateLinkValidationError(message=_private_link_info_map[env.cloud].error_message(),
                                      error_code='PRIVATE_LINK_VALIDATION_ERROR',
                                      status_code=403,
                                      )
16 changes: 8 additions & 8 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def test_config_azure_pat():
def test_config_azure_cli_host(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
cfg = Config(host='x', azure_workspace_resource_id='/sub/rg/ws')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
assert cfg.host == 'https://x'
assert cfg.host == 'https://adb-123.4.azuredatabricks.net'
assert cfg.is_azure


Expand Down Expand Up @@ -232,32 +232,32 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def
def test_config_azure_cli_host_and_resource_id(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
cfg = Config(host='x', azure_workspace_resource_id='/sub/rg/ws')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
assert cfg.host == 'https://x'
assert cfg.host == 'https://adb-123.4.azuredatabricks.net'
assert cfg.is_azure


def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch):
monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost')
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
cfg = Config(host='x', azure_workspace_resource_id='/sub/rg/ws')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
assert cfg.host == 'https://x'
assert cfg.host == 'https://adb-123.4.azuredatabricks.net'
assert cfg.is_azure


@raises(
"validate: more than one authorization method configured: azure and basic. Config: host=https://x, username=x, azure_workspace_resource_id=/sub/rg/ws. Env: DATABRICKS_USERNAME"
"validate: more than one authorization method configured: azure and basic. Config: host=https://adb-123.4.azuredatabricks.net, username=x, azure_workspace_resource_id=/sub/rg/ws. Env: DATABRICKS_USERNAME"
)
def test_config_azure_and_password_conflict(monkeypatch):
monkeypatch.setenv('DATABRICKS_USERNAME', 'x')
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
cfg = Config(host='x', azure_workspace_resource_id='/sub/rg/ws')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')


@raises(
Expand Down
20 changes: 15 additions & 5 deletions tests/test_auth_manual_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ def test_azure_cli_workspace_header_present(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-Workspace-Resource-Id' in cfg.authenticate()
assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id

Expand All @@ -16,7 +18,9 @@ def test_azure_cli_user_with_management_access(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


Expand All @@ -25,7 +29,9 @@ def test_azure_cli_user_no_management_access(monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate()


Expand All @@ -34,7 +40,9 @@ def test_azure_cli_fallback(monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
monkeypatch.setenv('FAIL_IF', 'subscription')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


Expand All @@ -43,5 +51,7 @@ def test_azure_cli_with_warning_on_stderr(monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
monkeypatch.setenv('WARN', 'this is a warning')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
import requests

from databricks.sdk import WorkspaceClient
from databricks.sdk.azure import ENVIRONMENTS, AzureEnvironment
from databricks.sdk.core import (ApiClient, Config, DatabricksError,
StreamingResponse)
from databricks.sdk.credentials_provider import (CliTokenSource,
CredentialsProvider,
DatabricksCliTokenSource,
HeaderFactory, databricks_cli)
from databricks.sdk.environments import Cloud, DatabricksEnvironment
from databricks.sdk.environments import (ENVIRONMENTS, AzureEnvironment, Cloud,
DatabricksEnvironment)
from databricks.sdk.service.catalog import PermissionsChange
from databricks.sdk.service.iam import AccessControlRequest
from databricks.sdk.version import __version__
Expand Down
8 changes: 8 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def test_missing_error_code():
assert errors.DatabricksError == type(err)


def test_private_link_error():
    """A response that landed on the private-link login redirect maps to PrivateLinkValidationError."""
    redirect_url = 'https://databricks.com/login.html?error=private-link-validation-error'
    api_request = requests.Request('GET', 'https://databricks.com/api/2.0/service').prepare()

    resp = requests.Response()
    resp.url = redirect_url
    resp.request = api_request

    err = errors.error_mapper(resp, {})
    # Exact type match on purpose: a parent class such as PermissionDenied must not pass.
    assert type(err) is errors.PrivateLinkValidationError


@pytest.mark.parametrize('status_code, error_code, klass',
[(400, ..., errors.BadRequest), (400, 'INVALID_PARAMETER_VALUE', errors.BadRequest),
(400, 'INVALID_PARAMETER_VALUE', errors.InvalidParameterValue),
Expand Down
Loading