Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ add workspace-level objects support #59

Merged
merged 6 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# macos

.DS_Store
*.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
12 changes: 4 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,10 @@ Security:

Workspace:

- [ ] Notebooks in the Workspace FS
- [ ] Directories in the Workspace FS
- [ ] Files in the Workspace FS

Repos:

- [ ] User-level Repos
- [ ] Org-level Repos
- [x] Notebooks
- [x] Directories
- [x] Repos
- [x] Files

Data access:

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"PyYAML>=6.0.0,<7.0.0",
"ratelimit>=2.2.1,<3.0.0",
"pandas>=2.0.3,<3.0.0",
"python-dotenv>=1.0.0,<=2.0.0"
"python-dotenv>=1.0.0,<=2.0.0",
"tenacity>=8.2.2,<9.0.0",
]

[project.optional-dependencies]
Expand Down
75 changes: 74 additions & 1 deletion src/uc_migration_toolkit/managers/inventory/inventorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@

from databricks.sdk.core import DatabricksError
from databricks.sdk.service.iam import AccessControlResponse, ObjectPermissions
from databricks.sdk.service.workspace import AclItem, SecretScope
from databricks.sdk.service.workspace import (
AclItem,
ObjectInfo,
ObjectType,
SecretScope,
)

from uc_migration_toolkit.managers.inventory.listing import WorkspaceListing
from uc_migration_toolkit.managers.inventory.types import (
AclItemsContainer,
LogicalObjectType,
Expand Down Expand Up @@ -170,3 +176,70 @@ def inventorize(self) -> list[PermissionsInventoryItem]:

def preload(self):
pass


class WorkspaceInventorizer(BaseInventorizer[InventoryObject]):
def __init__(self):
self.listing = WorkspaceListing(
provider.ws,
num_threads=config_provider.config.num_threads,
with_directories=False,
rate_limit=config_provider.config.rate_limit,
)

def preload(self):
pass

@staticmethod
def __convert_object_type_to_request_type(_object: ObjectInfo) -> RequestObjectType | None:
match _object.object_type:
case ObjectType.NOTEBOOK:
return RequestObjectType.NOTEBOOKS
case ObjectType.DIRECTORY:
return RequestObjectType.DIRECTORIES
case ObjectType.LIBRARY:
return None
case ObjectType.REPO:
return RequestObjectType.REPOS
case ObjectType.FILE:
return RequestObjectType.FILES
# silent handler for experiments - they'll be inventorized by the experiments manager
case None:
return None

@staticmethod
def __convert_request_object_type_to_logical_type(request_object_type: RequestObjectType) -> LogicalObjectType:
match request_object_type:
case RequestObjectType.NOTEBOOKS:
return LogicalObjectType.NOTEBOOK
case RequestObjectType.DIRECTORIES:
return LogicalObjectType.DIRECTORY
case RequestObjectType.REPOS:
return LogicalObjectType.REPO
case RequestObjectType.FILES:
return LogicalObjectType.FILE

def _convert_result_to_permission_item(self, _object: ObjectInfo) -> PermissionsInventoryItem | None:
request_object_type = self.__convert_object_type_to_request_type(_object)
if not request_object_type:
return
else:
permissions = provider.ws.permissions.get(
request_object_type=request_object_type, request_object_id=_object.object_id
)

inventory_item = PermissionsInventoryItem(
object_id=str(_object.object_id),
logical_object_type=self.__convert_request_object_type_to_logical_type(request_object_type),
request_object_type=request_object_type,
raw_object_permissions=json.dumps(permissions.as_dict()),
)
return inventory_item

def inventorize(self) -> list[PermissionsInventoryItem]:
self.listing.walk("/")
executables = [partial(self._convert_result_to_permission_item, _object) for _object in self.listing.results]
results = ThreadedExecution[PermissionsInventoryItem | None](executables).run()
results = [result for result in results if result]
logger.info(f"Permissions fetched for {len(results)} workspace objects")
return results
64 changes: 63 additions & 1 deletion src/uc_migration_toolkit/managers/inventory/listing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import datetime as dt
from collections.abc import Iterator
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait

from databricks.sdk.service.ml import ModelDatabricks
from databricks.sdk.service.workspace import ObjectType
from ratelimit import limits, sleep_and_retry

from uc_migration_toolkit.providers.client import provider
from uc_migration_toolkit.config import RateLimitConfig
from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient, provider
from uc_migration_toolkit.providers.config import provider as config_provider
from uc_migration_toolkit.providers.logger import logger


class CustomListing:
Expand All @@ -15,3 +22,58 @@ def list_models() -> Iterator[ModelDatabricks]:
for model in provider.ws.model_registry.list_models():
model_with_id = provider.ws.model_registry.get_model(model.name).registered_model_databricks
yield model_with_id


class WorkspaceListing:
def __init__(
self,
ws: ImprovedWorkspaceClient,
num_threads: int,
*,
with_directories: bool = True,
rate_limit: RateLimitConfig | None = None,
):
self.start_time = None
self._ws = ws
self.results = []
self._num_threads = num_threads
self._with_directories = with_directories
self._counter = 0
self._rate_limit = rate_limit if rate_limit else config_provider.config.rate_limit

@sleep_and_retry
@limits(calls=self._rate_limit.max_requests_per_period, period=self._rate_limit.period_in_seconds)
def _rate_limited_listing(path: str) -> Iterator[ObjectType]:
return self._ws.workspace.list(path=path, recursive=False)

self._rate_limited_listing = _rate_limited_listing

def _progress_report(self, _):
self._counter += 1
measuring_time = dt.datetime.now()
delta_from_start = measuring_time - self.start_time
rps = self._counter / delta_from_start.total_seconds()
if self._counter % 10 == 0:
logger.info(
f"Made {self._counter} workspace listing calls, "
f"collected {len(self.results)} objects, rps: {rps:.3f}/sec"
)

def _walk(self, _path: str):
futures = []
with ThreadPoolExecutor(self._num_threads) as executor:
for _obj in self._rate_limited_listing(_path):
if _obj.object_type == ObjectType.DIRECTORY:
if self._with_directories:
self.results.append(_obj)
future = executor.submit(self._walk, _obj.path)
future.add_done_callback(self._progress_report)
futures.append(future)
else:
self.results.append(_obj)
wait(futures, return_when=ALL_COMPLETED)

def walk(self, path: str):
self.start_time = dt.datetime.now()
self._walk(path)
self._progress_report(None) # report the final progress
64 changes: 45 additions & 19 deletions src/uc_migration_toolkit/managers/inventory/permissions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import random
import time
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Literal

from databricks.sdk.service.iam import AccessControlRequest, Group, ObjectPermissions
from databricks.sdk.service.workspace import AclItem as SdkAclItem
from tenacity import retry, stop_after_attempt, wait_fixed, wait_random

from uc_migration_toolkit.managers.group import MigrationGroupsProvider
from uc_migration_toolkit.managers.inventory.inventorizer import (
SecretScopeInventorizer,
StandardInventorizer,
TokensAndPasswordsInventorizer,
WorkspaceInventorizer,
)
from uc_migration_toolkit.managers.inventory.listing import CustomListing
from uc_migration_toolkit.managers.inventory.table import InventoryTableManager
Expand All @@ -23,7 +27,7 @@
from uc_migration_toolkit.providers.client import provider
from uc_migration_toolkit.providers.config import provider as config_provider
from uc_migration_toolkit.providers.logger import logger
from uc_migration_toolkit.utils import ThreadedExecution
from uc_migration_toolkit.utils import ThreadedExecution, safe_get_acls


@dataclass
Expand Down Expand Up @@ -98,6 +102,7 @@ def get_inventorizers():
id_attribute="id",
),
SecretScopeInventorizer(),
WorkspaceInventorizer(),
]

def inventorize_permissions(self):
Expand Down Expand Up @@ -197,27 +202,49 @@ def _prepare_new_permission_request(
)

@staticmethod
def _permission_applicator(request_payload: PermissionRequestPayload | SecretsPermissionRequestPayload):
if isinstance(request_payload, PermissionRequestPayload):
provider.ws.permissions.update(
request_object_type=request_payload.request_object_type,
request_object_id=request_payload.object_id,
access_control_list=request_payload.access_control_list,
@retry(wait=wait_fixed(1) + wait_random(0, 2), stop=stop_after_attempt(5))
def _scope_permissions_applicator(request_payload: SecretsPermissionRequestPayload):
# TODO: rewrite and generalize this
for _acl_item in request_payload.access_control_list:
# this request will create OR update the ACL for the given principal
# it means that the access_control_list should only keep records required for update
provider.ws.secrets.put_acl(
scope=request_payload.object_id, principal=_acl_item.principal, permission=_acl_item.permission
)
elif isinstance(request_payload, SecretsPermissionRequestPayload):
for _acl_item in request_payload.access_control_list:
# this request will create OR update the ACL for the given principal
# it means that the access_control_list should only keep records required for update
provider.ws.secrets.put_acl(
scope=request_payload.object_id, principal=_acl_item.principal, permission=_acl_item.permission
logger.debug(f"Applied new permissions for scope {request_payload.object_id}: {_acl_item}")
# in-flight check for the applied permissions
# the api might be inconsistent, therefore we need to check that the permissions were applied
for _ in range(3):
time.sleep(random.random() * 2)
applied_acls = safe_get_acls(
provider.ws, scope_name=request_payload.object_id, group_name=_acl_item.principal
)
assert applied_acls, f"Failed to apply permissions for {_acl_item.principal}"
assert applied_acls.permission == _acl_item.permission, (
f"Failed to apply permissions for {_acl_item.principal}. "
f"Expected: {_acl_item.permission}. Actual: {applied_acls.permission}"
)

@staticmethod
def _standard_permissions_applicator(request_payload: PermissionRequestPayload):
provider.ws.permissions.update(
request_object_type=request_payload.request_object_type,
request_object_id=request_payload.object_id,
access_control_list=request_payload.access_control_list,
)

def applicator(self, request_payload: PermissionRequestPayload | SecretsPermissionRequestPayload):
if isinstance(request_payload, PermissionRequestPayload):
self._standard_permissions_applicator(request_payload)
elif isinstance(request_payload, SecretsPermissionRequestPayload):
self._scope_permissions_applicator(request_payload)
else:
logger.warning(f"Unsupported logical object type {request_payload}")
logger.warning(f"Unsupported payload type {type(request_payload)}")

def _apply_permissions_in_parallel(
self, requests: list[PermissionRequestPayload | SecretsPermissionRequestPayload]
):
executables = [partial(self._permission_applicator, payload) for payload in requests]
executables = [partial(self.applicator, payload) for payload in requests]
execution = ThreadedExecution[None](executables)
execution.run()

Expand All @@ -230,12 +257,11 @@ def apply_group_permissions(
permissions_on_source = self.inventory_table_manager.load_for_groups(
groups=[g.workspace.display_name for g in migration_groups_provider.groups]
)
applicable_permissions = [
permission_payloads: list[PermissionRequestPayload | SecretsPermissionRequestPayload] = [
self._prepare_new_permission_request(item, migration_groups_provider, destination=destination)
for item in permissions_on_source
]
logger.info(f"Applying {len(permission_payloads)} permissions")

logger.info(f"Applying {len(applicable_permissions)} permissions")

self._apply_permissions_in_parallel(requests=applicable_permissions)
self._apply_permissions_in_parallel(requests=permission_payloads)
logger.info("All permissions were applied")
4 changes: 4 additions & 0 deletions src/uc_migration_toolkit/managers/inventory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __repr__(self):


class LogicalObjectType(StrEnum):
FILE = "FILE"
REPO = "REPO"
DIRECTORY = "DIRECTORY"
NOTEBOOK = "NOTEBOOK"
SECRET_SCOPE = "SECRET_SCOPE"
PASSWORD = "PASSWORD"
TOKEN = "TOKEN"
Expand Down
14 changes: 7 additions & 7 deletions src/uc_migration_toolkit/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import concurrent
import datetime as dt
import enum
import json
from collections.abc import Callable
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor
from typing import Generic, TypeVar

from databricks.sdk.service.workspace import AclItem
from ratelimit import limits, sleep_and_retry

from uc_migration_toolkit.config import RateLimitConfig
from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient
from uc_migration_toolkit.providers.config import provider as config_provider
from uc_migration_toolkit.providers.logger import logger

Expand Down Expand Up @@ -107,9 +108,8 @@ class WorkspaceLevelEntitlement(StrEnum):
ALLOW_INSTANCE_POOL_CREATE = "allow-instance-pool-create"


# TODO: using this because SDK doesn't know how to properly write enums, highlight this to the SDK team
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, enum.Enum):
return obj.name
return json.JSONEncoder.default(self, obj)
def safe_get_acls(ws: ImprovedWorkspaceClient, scope_name: str, group_name: str) -> AclItem | None:
all_acls = ws.secrets.list_acls(scope=scope_name)
for acl in all_acls:
if acl.principal == group_name:
return acl
Loading