Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental command for enabling HMS federation #2939

Merged
merged 3 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions labs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,8 @@ commands:

- name: create-federated-catalog
description: (EXPERIMENTAL) Create a federated catalog in the workspace

- name: enable-hms-federation
description: (EXPERIMENTAL) Enable HMS federation based migration flow. When this is enabled, UCX will create a federated HMS catalog which syncs from the workspace HMS.


5 changes: 5 additions & 0 deletions src/databricks/labs/ucx/aws/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ def __init__(
external_locations: ExternalLocations,
aws_resource_permissions: AWSResourcePermissions,
principal_acl: PrincipalACL,
enable_hms_federation: bool = False,
):
self._ws = ws
self._external_locations = external_locations
self._aws_resource_permissions = aws_resource_permissions
self._principal_acl = principal_acl
# When HMS federation is enabled, the fallback bit is set for all the
# locations which are created by UCX.
self._enable_fallback_mode = enable_hms_federation

def run(self) -> None:
"""
Expand All @@ -52,6 +56,7 @@ def run(self) -> None:
path,
credential_dict[role_arn],
skip_validation=True,
fallback=self._enable_fallback_mode,
)
self._principal_acl.apply_location_acl()

Expand Down
10 changes: 9 additions & 1 deletion src/databricks/labs/ucx/azure/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ def __init__(
resource_permissions: AzureResourcePermissions,
azurerm: AzureResources,
principal_acl: PrincipalACL,
enable_hms_federation: bool = False,
):
self._ws = ws
self._hms_locations = hms_locations
self._resource_permissions = resource_permissions
self._azurerm = azurerm
self._principal_acl = principal_acl
self._enable_fallback_mode = enable_hms_federation

def _app_id_credential_name_mapping(self) -> tuple[dict[str, str], dict[str, str]]:
# list all storage credentials.
Expand Down Expand Up @@ -120,7 +122,13 @@ def _create_external_location_helper(
) -> str | None:
try:
self._ws.external_locations.create(
name, url, credential, comment=comment, read_only=read_only, skip_validation=skip_validation
name,
url,
credential,
comment=comment,
read_only=read_only,
skip_validation=skip_validation,
fallback=self._enable_fallback_mode,
)
return url
except InvalidParameterValue as invalid:
Expand Down
8 changes: 8 additions & 0 deletions src/databricks/labs/ucx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,5 +808,13 @@ def create_federated_catalog(w: WorkspaceClient, _: Prompts):
ctx.federation.register_internal_hms_as_federated_catalog()


@ucx.command
def enable_hms_federation(w: WorkspaceClient, _: Prompts, ctx: WorkspaceContext | None = None):
    """(Experimental) Enable the HMS federation based migration flow.

    This command does not create a federated catalog itself; it only asks the
    context's federation enabler to switch the feature on.

    Args:
        w: Workspace client used to build a `WorkspaceContext` when `ctx` is not supplied.
        _: Prompts handle; unused by this command.
        ctx: Optional pre-built workspace context. A new one is created from `w` when omitted.
    """
    if not ctx:
        ctx = WorkspaceContext(w)
    ctx.federation_enabler.enable()


if __name__ == "__main__":
ucx()
2 changes: 2 additions & 0 deletions src/databricks/labs/ucx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes
# [INTERNAL ONLY] Whether the assessment should lint only specific dashboards.
include_dashboard_ids: list[str] | None = None

enable_hms_federation: bool = False

managed_table_external_storage: str = 'CLONE'

def replace_inventory_variable(self, text: str) -> str:
Expand Down
15 changes: 13 additions & 2 deletions src/databricks/labs/ucx/contexts/workspace_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from databricks.labs.ucx.azure.locations import ExternalLocationsMigration
from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources
from databricks.labs.ucx.contexts.application import CliContext
from databricks.labs.ucx.hive_metastore.federation import HiveMetastoreFederation
from databricks.labs.ucx.hive_metastore.federation import HiveMetastoreFederation, HiveMetastoreFederationEnabler
from databricks.labs.ucx.hive_metastore.table_migration_status import TableMigrationIndex
from databricks.labs.ucx.progress.install import ProgressTrackingInstallation
from databricks.labs.ucx.source_code.base import CurrentSessionState
Expand Down Expand Up @@ -118,6 +118,7 @@ def external_locations_migration(self) -> AWSExternalLocationsMigration | Extern
self.external_locations,
self.aws_resource_permissions,
self.principal_acl,
self.config.enable_hms_federation,
)
if self.is_azure:
return ExternalLocationsMigration(
Expand All @@ -126,6 +127,7 @@ def external_locations_migration(self) -> AWSExternalLocationsMigration | Extern
self.azure_resource_permissions,
self.azure_resources,
self.principal_acl,
self.config.enable_hms_federation,
)
raise NotImplementedError

Expand Down Expand Up @@ -188,9 +190,18 @@ def notebook_loader(self) -> NotebookLoader:
def progress_tracking_installation(self) -> ProgressTrackingInstallation:
return ProgressTrackingInstallation(self.sql_backend, self.config.ucx_catalog)

@cached_property
def federation_enabler(self) -> HiveMetastoreFederationEnabler:
    """Return the helper used to switch on HMS federation for this installation.

    The helper is constructed from this context's installation the first time
    the property is read and is cached on the instance afterwards.
    """
    return HiveMetastoreFederationEnabler(self.installation)

@cached_property
def federation(self) -> HiveMetastoreFederation:
    """Return the HMS federation helper wired to this workspace context.

    The configured `enable_hms_federation` flag is forwarded so the helper can
    tell whether the feature has been enabled. The instance is built on first
    access and cached afterwards.
    """
    return HiveMetastoreFederation(
        self.workspace_client,
        self.external_locations,
        self.workspace_info,
        # Without this flag the helper falls back to its default of False.
        self.config.enable_hms_federation,
    )


class LocalCheckoutContext(WorkspaceContext):
Expand Down
22 changes: 19 additions & 3 deletions src/databricks/labs/ucx/hive_metastore/federation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import logging

from databricks.labs.blueprint.installation import Installation
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import AlreadyExists, NotFound, BadRequest
from databricks.sdk.service.catalog import (
Expand All @@ -13,24 +14,39 @@
)

from databricks.labs.ucx.account.workspaces import WorkspaceInfo
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.hive_metastore import ExternalLocations


logger = logging.getLogger(__name__)


class HiveMetastoreFederationEnabler:
    """Switches on the `enable_hms_federation` flag in the saved workspace config."""

    def __init__(self, installation: Installation):
        # Installation handle used to load and save the workspace config.
        self._installation = installation

    def enable(self):
        """Load the workspace config, set `enable_hms_federation` to True and save it back."""
        workspace_config = self._installation.load(WorkspaceConfig)
        workspace_config.enable_hms_federation = True
        self._installation.save(workspace_config)


class HiveMetastoreFederation:
def __init__(
    self,
    workspace_client: WorkspaceClient,
    external_locations: ExternalLocations,
    workspace_info: WorkspaceInfo,
    enable_hms_federation: bool = False,
):
    """Store the collaborators and the feature flag used by the federation helper.

    Args:
        workspace_client: Databricks workspace client.
        external_locations: Provider of the external locations snapshot.
        workspace_info: Provider of the current workspace information.
        enable_hms_federation: Whether HMS federation has been enabled. When
            False (the default), registering the federated catalog is refused.
    """
    # Feature gate checked before registering the federated catalog.
    self._enable_hms_federation = enable_hms_federation
    # Collaborators; the assignments are independent of each other.
    self._workspace_info = workspace_info
    self._external_locations = external_locations
    self._workspace_client = workspace_client

def register_internal_hms_as_federated_catalog(self) -> CatalogInfo:
if not self._enable_hms_federation:
raise RuntimeWarning('Run `databricks labs ucx enable-hms-federation` to enable HMS Federation')
name = self._workspace_info.current()
connection_info = self._get_or_create_connection(name)
assert connection_info.name is not None
Expand All @@ -52,7 +68,7 @@ def _get_or_create_connection(self, name: str) -> ConnectionInfo:
try:
return self._workspace_client.connections.create(
name=name,
connection_type=ConnectionType.HIVE_METASTORE,
connection_type=ConnectionType.HIVE_METASTORE, # needs SDK change
options={"builtin": "true"},
)
except AlreadyExists:
Expand All @@ -68,7 +84,7 @@ def _get_authorized_paths(self) -> str:
authorized_paths = []
current_user = self._workspace_client.current_user.me()
if not current_user.user_name:
raise ValueError('Current user not found')
raise NotFound('Current user not found')
for external_location_info in self._external_locations.snapshot():
location = external_location_info.location.rstrip('/').replace('s3a://', 's3://')
existing_location = existing.get(location)
Expand All @@ -81,7 +97,7 @@ def _get_authorized_paths(self) -> str:
continue
self._add_missing_permissions_if_needed(location_name, current_user.user_name)
authorized_paths.append(location)
return ",".join(sorted(authorized_paths))
return ",".join(authorized_paths)

def _add_missing_permissions_if_needed(self, location_name: str, current_user: str):
grants = self._location_grants(location_name)
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/aws/test_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def test_create_external_locations(mock_ws, installation_multiple_roles, backend
)
external_locations_migration.run()
calls = [
call('bucket1_folder1', 's3://BUCKET1/FOLDER1', 'cred1', skip_validation=True),
call('bucket2_folder2', 's3://BUCKET2/FOLDER2', 'cred1', skip_validation=True),
call('bucketx_folderx', 's3://BUCKETX/FOLDERX', 'credx', skip_validation=True),
call('bucket1_folder1', 's3://BUCKET1/FOLDER1', 'cred1', skip_validation=True, fallback=False),
call('bucket2_folder2', 's3://BUCKET2/FOLDER2', 'cred1', skip_validation=True, fallback=False),
call('bucketx_folderx', 's3://BUCKETX/FOLDERX', 'credx', skip_validation=True, fallback=False),
]
mock_ws.external_locations.create.assert_has_calls(calls, any_order=True)
aws.get_role_policy.assert_not_called()
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_create_external_locations_skip_existing(mock_ws, backend, locations):
)
external_locations_migration.run()
calls = [
call("bucket1_folder1", 's3://BUCKET1/FOLDER1', 'cred1', skip_validation=True),
call("bucket1_folder1", 's3://BUCKET1/FOLDER1', 'cred1', skip_validation=True, fallback=False),
]
mock_ws.external_locations.create.assert_has_calls(calls, any_order=True)
aws.get_role_policy.assert_not_called()
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/azure/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def test_run_service_principal():
comment="Created by UCX",
read_only=False,
skip_validation=False,
fallback=False,
)
ws.external_locations.create.assert_any_call(
"container2_test",
Expand All @@ -113,6 +114,7 @@ def test_run_service_principal():
comment="Created by UCX",
read_only=True,
skip_validation=False,
fallback=False,
)


Expand Down Expand Up @@ -186,6 +188,7 @@ def test_skip_unsupported_location(caplog):
comment="Created by UCX",
read_only=False,
skip_validation=False,
fallback=False,
)
assert "Skip unsupported location: adl://container2@test.dfs.core.windows.net" in caplog.text
assert "Skip unsupported location: wasbs://container2@test.dfs.core.windows.net" in caplog.text
Expand Down Expand Up @@ -258,6 +261,7 @@ def test_run_managed_identity():
comment="Created by UCX",
read_only=False,
skip_validation=False,
fallback=False,
)
ws.external_locations.create.assert_any_call(
"container5_test_a_b",
Expand All @@ -266,6 +270,7 @@ def test_run_managed_identity():
comment="Created by UCX",
read_only=True,
skip_validation=False,
fallback=False,
)


Expand Down Expand Up @@ -336,6 +341,7 @@ def test_run_access_connectors():
comment="Created by UCX",
read_only=False,
skip_validation=False,
fallback=False,
),
call(
"container5_test_a_b",
Expand All @@ -344,6 +350,7 @@ def test_run_access_connectors():
comment="Created by UCX",
read_only=False,
skip_validation=False,
fallback=False,
),
]
ws.external_locations.create.assert_has_calls(calls)
Expand Down Expand Up @@ -451,6 +458,7 @@ def test_location_failed_to_read():
comment="Created by UCX",
read_only=True,
skip_validation=True,
fallback=False,
)


Expand Down
27 changes: 24 additions & 3 deletions tests/unit/hive_metastore/test_federation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import create_autospec, call

from databricks.labs.blueprint.installation import MockInstallation
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import AlreadyExists
from databricks.sdk.service.catalog import (
Expand All @@ -16,8 +17,9 @@
from databricks.sdk.service.iam import User

from databricks.labs.ucx.account.workspaces import WorkspaceInfo
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.hive_metastore import ExternalLocations
from databricks.labs.ucx.hive_metastore.federation import HiveMetastoreFederation
from databricks.labs.ucx.hive_metastore.federation import HiveMetastoreFederation, HiveMetastoreFederationEnabler
from databricks.labs.ucx.hive_metastore.locations import ExternalLocation


Expand All @@ -42,7 +44,7 @@ def test_create_federated_catalog():
privilege_assignments=[PrivilegeAssignment(privileges=[Privilege.MANAGE], principal='any')]
)

hms_fed = HiveMetastoreFederation(workspace_client, external_locations, workspace_info)
hms_fed = HiveMetastoreFederation(workspace_client, external_locations, workspace_info, enable_hms_federation=True)
hms_fed.register_internal_hms_as_federated_catalog()

workspace_client.connections.create.assert_called_with(
Expand Down Expand Up @@ -94,7 +96,7 @@ def test_already_existing_connection():
privilege_assignments=[PrivilegeAssignment(privileges=[Privilege.MANAGE], principal='any')]
)

hms_fed = HiveMetastoreFederation(workspace_client, external_locations, workspace_info)
hms_fed = HiveMetastoreFederation(workspace_client, external_locations, workspace_info, enable_hms_federation=True)
hms_fed.register_internal_hms_as_federated_catalog()

workspace_client.connections.create.assert_called_with(
Expand All @@ -107,3 +109,22 @@ def test_already_existing_connection():
connection_name='a',
options={"authorized_paths": 's3://b/c/d,s3://e/f/g'},
)


def test_hms_federation_enabler():
    """Enabling HMS federation must set the flag in the stored workspace config."""
    # Arrange: an installation whose config.yml does not mention the flag.
    stored_config = {
        'inventory_database': 'ucx',
        'connect': {
            'host': 'host',
            'token': 'token',
        },
    }
    installation = MockInstallation({'config.yml': stored_config})

    # Act
    HiveMetastoreFederationEnabler(installation).enable()

    # Assert: reloading the config through the same installation shows the flag set.
    reloaded = installation.load(WorkspaceConfig)
    assert reloaded.enable_hms_federation is True