diff --git a/pyproject.toml b/pyproject.toml index d4c4b5aeaf..d331521706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ dependencies = [ [tool.hatch.envs.lint.scripts] fmt = [ "black .", - "ruff --fix .", + "ruff . --fix", "isort ." ] verify = [ diff --git a/src/uc_migration_toolkit/config.py b/src/uc_migration_toolkit/config.py index 65e7e3e52e..6f5f528e3e 100644 --- a/src/uc_migration_toolkit/config.py +++ b/src/uc_migration_toolkit/config.py @@ -31,6 +31,7 @@ def __post_init__(self): raise ValueError(msg) +# TODO: replace with databricks.sdk.core.Config @dataclass class InventoryConfig: table: InventoryTable @@ -50,9 +51,26 @@ class ConnectConfig: azure_environment: str | None = None cluster_id: str | None = None profile: str | None = None - debug_headers: bool = False + debug_headers: bool | None = False rate_limit: int | None = None + @staticmethod + def from_databricks_config(cfg: Config) -> "ConnectConfig": + return ConnectConfig( + host=cfg.host, + token=cfg.token, + client_id=cfg.client_id, + client_secret=cfg.client_secret, + azure_client_id=cfg.azure_client_id, + azure_tenant_id=cfg.azure_tenant_id, + azure_client_secret=cfg.azure_client_secret, + azure_environment=cfg.azure_environment, + cluster_id=cfg.cluster_id, + profile=cfg.profile, + debug_headers=cfg.debug_headers, + rate_limit=cfg.rate_limit, + ) + @dataclass class MigrationConfig: diff --git a/src/uc_migration_toolkit/providers/spark.py b/src/uc_migration_toolkit/providers/spark.py index 95fcef1950..d9f7dc68d3 100644 --- a/src/uc_migration_toolkit/providers/spark.py +++ b/src/uc_migration_toolkit/providers/spark.py @@ -15,7 +15,7 @@ def _initialize_spark(ws: WorkspaceClient) -> SparkSession: from databricks.sdk.runtime import spark return spark - except ValueError: + except ImportError: logger.info("Using DB Connect") from databricks.connect import DatabricksSession diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fc125372bf..f569e47e20 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,10 +1,14 @@ import io import json +import logging import os +import pathlib import random +import sys import uuid from functools import partial +import databricks.sdk.core import pytest from _pytest.fixtures import SubRequest from databricks.sdk import AccountClient @@ -33,7 +37,14 @@ ObjectType, SecretScope, ) -from utils import ( + +from uc_migration_toolkit.config import InventoryTable +from uc_migration_toolkit.managers.inventory.types import RequestObjectType +from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient +from uc_migration_toolkit.providers.logger import logger +from uc_migration_toolkit.utils import Request, ThreadedExecution + +from .utils import ( EnvironmentInfo, InstanceProfile, WorkspaceObjects, @@ -42,16 +53,9 @@ _get_basic_job_cluster, _get_basic_task, _set_random_permissions, - initialize_env, ) -from uc_migration_toolkit.config import InventoryTable -from uc_migration_toolkit.managers.inventory.types import RequestObjectType -from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient -from uc_migration_toolkit.providers.logger import logger -from uc_migration_toolkit.utils import Request, ThreadedExecution - -initialize_env() +logging.getLogger("databricks.sdk").setLevel(logging.INFO) NUM_TEST_GROUPS = int(os.environ.get("NUM_TEST_GROUPS", 5)) NUM_TEST_INSTANCE_PROFILES = int(os.environ.get("NUM_TEST_INSTANCE_PROFILES", 3)) @@ -72,14 +76,45 @@ Threader = partial(ThreadedExecution, num_threads=NUM_THREADS) +def account_host(self: databricks.sdk.core.Config) -> str: + if self.is_azure: + return "https://accounts.azuredatabricks.net" + elif self.is_gcp: + return "https://accounts.gcp.databricks.com/" + else: + return "https://accounts.cloud.databricks.com" + + +def _load_debug_env_if_runs_from_ide(key) -> bool: + if not _is_in_debug(): + return False + conf_file = pathlib.Path.home() / ".databricks/debug-env.json" + with conf_file.open("r") as f: + conf = json.load(f) + if key not in conf: + msg = f"{key} not found in ~/.databricks/debug-env.json" + raise KeyError(msg) + for k, v in conf[key].items(): + os.environ[k] = v + return True + + +def _is_in_debug() -> bool: + return os.path.basename(sys.argv[0]) in [ + "_jb_pytest_runner.py", + "testlauncher.py", + ] + + @pytest.fixture(scope="session") def ws() -> ImprovedWorkspaceClient: # Use variables from Unified Auth # See https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html + _load_debug_env_if_runs_from_ide("ucws") return ImprovedWorkspaceClient() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def acc(ws) -> AccountClient: # TODO: move to SDK def account_host(cfg: Config) -> str: @@ -95,38 +130,39 @@ def account_host(cfg: Config) -> str: return AccountClient(host=account_host(ws.config)) -@pytest.fixture(scope="session", autouse=True) -def dbconnect(ws: ImprovedWorkspaceClient): +@pytest.fixture(scope="session") +def dbconnect_cluster_id(ws: ImprovedWorkspaceClient) -> str: dbc_cluster = next(filter(lambda c: c.cluster_name == DB_CONNECT_CLUSTER_NAME, ws.clusters.list()), None) if dbc_cluster: logger.debug(f"Integration testing cluster {DB_CONNECT_CLUSTER_NAME} already exists, skipping it's creation") - else: - logger.debug("Creating a cluster for integration testing") - request = { - "cluster_name": DB_CONNECT_CLUSTER_NAME, - "spark_version": "13.2.x-scala2.12", - "instance_pool_id": os.environ["TEST_POOL_ID"], - "driver_instance_pool_id": os.environ["TEST_POOL_ID"], - "num_workers": 0, - "spark_conf": {"spark.master": "local[*, 4]", "spark.databricks.cluster.profile": "singleNode"}, - "custom_tags": { - "ResourceClass": "SingleNode", - }, - "data_security_mode": "SINGLE_USER", - "autotermination_minutes": 180, - "runtime_engine": "PHOTON", - } - - dbc_cluster = ws.clusters.create(spark_version="13.2.x-scala2.12", request=Request(request)) - - logger.debug(f"Cluster {dbc_cluster.cluster_id} created") - - os.environ["DATABRICKS_CLUSTER_ID"] = dbc_cluster.cluster_id - yield + return dbc_cluster.cluster_id + + logger.debug("Creating a cluster for integration testing") + spark_version = ws.clusters.select_spark_version(latest=True) + request = { + "cluster_name": DB_CONNECT_CLUSTER_NAME, + "spark_version": spark_version, + "instance_pool_id": os.environ["TEST_INSTANCE_POOL_ID"], + "driver_instance_pool_id": os.environ["TEST_INSTANCE_POOL_ID"], + "num_workers": 0, + "spark_conf": {"spark.master": "local[*, 4]", "spark.databricks.cluster.profile": "singleNode"}, + "custom_tags": { + "ResourceClass": "SingleNode", + }, + "data_security_mode": "SINGLE_USER", + "autotermination_minutes": 180, + "runtime_engine": "PHOTON", + } + + dbc_cluster = ws.clusters.create(spark_version=spark_version, request=Request(request)) + logger.debug(f"Cluster {dbc_cluster.cluster_id} created") + + # TODO: pre-create the cluster in the test infra + return dbc_cluster.cluster_id -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def env(ws: ImprovedWorkspaceClient, acc: AccountClient, request: SubRequest) -> EnvironmentInfo: # prepare environment test_uid = f"{UCX_TESTING_PREFIX}_{str(uuid.uuid4())[:8]}" @@ -151,7 +187,9 @@ def _wrapped(*args, **kwargs): silent_delete = error_silencer(ws.groups.delete) temp_cleanups = [ - partial(silent_delete, g.id) for g in ws.groups.list(filter=f"displayName sw 'db-temp-{test_uid}'") + # TODO: this is too heavy for SCIM API, refactor to ID lookup + partial(silent_delete, g.id) + for g in ws.groups.list(filter=f"displayName sw 'db-temp-{test_uid}'") ] new_ws_groups_cleanups = [ partial(silent_delete, g.id) for g in ws.groups.list(filter=f"displayName sw '{test_uid}'") @@ -165,7 +203,7 @@ def _wrapped(*args, **kwargs): yield EnvironmentInfo(test_uid=test_uid, groups=groups) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def instance_profiles(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[InstanceProfile]: logger.debug("Adding test instance profiles") profiles: list[InstanceProfile] = [] @@ -199,7 +237,7 @@ def instance_profiles(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list logger.debug("Test instance profiles deleted") -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def instance_pools(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateInstancePoolResponse]: logger.debug("Creating test instance pools") @@ -224,7 +262,7 @@ def instance_pools(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[Cr Threader(executables).run() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def pipelines(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreatePipelineResponse]: logger.debug("Creating test DLT pipelines") @@ -254,7 +292,7 @@ def pipelines(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateP Threader(executables).run() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def jobs(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateResponse]: logger.debug("Creating test jobs") @@ -281,7 +319,7 @@ def jobs(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreateRespon Threader(executables).run() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def cluster_policies(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[CreatePolicyResponse]: logger.debug("Creating test cluster policies") @@ -316,16 +354,16 @@ def cluster_policies(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ Threader(executables).run() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def clusters(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ClusterDetails]: logger.debug("Creating test clusters") creators = [ partial( ws.clusters.create, - spark_version="13.2.x-scala2.12", - instance_pool_id=os.environ["TEST_POOL_ID"], - driver_instance_pool_id=os.environ["TEST_POOL_ID"], + spark_version=ws.clusters.select_spark_version(latest=True), + instance_pool_id=os.environ["TEST_INSTANCE_POOL_ID"], + driver_instance_pool_id=os.environ["TEST_INSTANCE_POOL_ID"], cluster_name=f"{env.test_uid}-test-{i}", num_workers=1, ) @@ -351,7 +389,7 @@ def clusters(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> list[ClusterD logger.debug("Test clusters deleted") -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def experiments(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[CreateExperimentResponse]: logger.debug("Creating test experiments") @@ -384,7 +422,7 @@ def experiments(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[Creat logger.debug("Test experiments deleted") -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def models(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[ModelDatabricks]: logger.debug("Creating models") @@ -417,7 +455,7 @@ def models(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[ModelDatab logger.debug("Test models deleted") -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def warehouses(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[GetWarehouseResponse]: logger.debug("Creating warehouses") @@ -452,7 +490,7 @@ def warehouses(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[GetWar logger.debug("Test warehouses deleted") -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def tokens(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[AccessControlRequest]: logger.debug("Adding token-level permissions to groups") @@ -470,7 +508,7 @@ def tokens(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[AccessCont yield token_permissions -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def secret_scopes(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[SecretScope]: logger.debug("Creating test secret scopes") @@ -491,7 +529,7 @@ def secret_scopes(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> list[Sec Threader(executables).run() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def workspace_objects(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> WorkspaceObjects: logger.info(f"Creating test workspace objects under /{env.test_uid}") ws.workspace.mkdirs(f"/{env.test_uid}") @@ -546,7 +584,7 @@ def workspace_objects(ws: ImprovedWorkspaceClient, env: EnvironmentInfo) -> Work logger.debug("Test workspace objects deleted") -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def verifiable_objects( clusters, instance_pools, @@ -577,7 +615,8 @@ def verifiable_objects( @pytest.fixture() -def inventory_table(env: EnvironmentInfo, ws: ImprovedWorkspaceClient) -> InventoryTable: +def inventory_table(env: EnvironmentInfo, ws: ImprovedWorkspaceClient, dbconnect_cluster_id: str) -> InventoryTable: + ws.config.cluster_id = dbconnect_cluster_id table = InventoryTable( catalog="main", database="default", diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index bd3b455289..5393e501f8 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -9,9 +9,9 @@ ) from databricks.sdk.service.workspace import SecretScope from pyspark.errors import AnalysisException -from utils import EnvironmentInfo, WorkspaceObjects from uc_migration_toolkit.config import ( + ConnectConfig, GroupsConfig, InventoryConfig, InventoryTable, @@ -24,6 +24,8 @@ from uc_migration_toolkit.toolkits.group_migration import GroupMigrationToolkit from uc_migration_toolkit.utils import safe_get_acls +from .utils import EnvironmentInfo, WorkspaceObjects + def _verify_group_permissions( objects: list | WorkspaceObjects | None, @@ -168,13 +170,13 @@ def test_e2e( logger.debug(f"Test environment: {env.test_uid}") config = MigrationConfig( + connect=ConnectConfig.from_databricks_config(ws.config), with_table_acls=False, inventory=InventoryConfig(table=inventory_table), groups=GroupsConfig(selected=[g[0].display_name for g in env.groups]), auth=None, log_level="TRACE", ) - logger.debug(f"Starting e2e with config: {config.to_json()}") toolkit = GroupMigrationToolkit(config) toolkit.prepare_environment() diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py new file mode 100644 index 0000000000..4f8fb0dd16 --- /dev/null +++ b/tests/integration/test_jobs.py @@ -0,0 +1,101 @@ +import pytest +from pyspark.errors import AnalysisException + +from uc_migration_toolkit.config import ( + ConnectConfig, + GroupsConfig, + InventoryConfig, + InventoryTable, + MigrationConfig, +) +from uc_migration_toolkit.managers.inventory.types import RequestObjectType +from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient +from uc_migration_toolkit.providers.logger import logger +from uc_migration_toolkit.toolkits.group_migration import GroupMigrationToolkit + +from .test_e2e import _verify_group_permissions, _verify_roles_and_entitlements +from .utils import EnvironmentInfo + + +def test_jobs( + env: EnvironmentInfo, + inventory_table: InventoryTable, + ws: ImprovedWorkspaceClient, + jobs, +): + logger.debug(f"Test environment: {env.test_uid}") + + config = MigrationConfig( + connect=ConnectConfig.from_databricks_config(ws.config), + with_table_acls=False, + inventory=InventoryConfig(table=inventory_table), + groups=GroupsConfig(selected=[g[0].display_name for g in env.groups]), + auth=None, + log_level="TRACE", + ) + toolkit = GroupMigrationToolkit(config) + toolkit.prepare_environment() + + logger.debug("Verifying that the groups were created") + + assert len(ws.groups.list(filter=f"displayName sw '{config.groups.backup_group_prefix}{env.test_uid}'")) == len( + toolkit.group_manager.migration_groups_provider.groups + ) + + assert len(ws.groups.list(filter=f"displayName sw '{env.test_uid}'")) == len( + toolkit.group_manager.migration_groups_provider.groups + ) + + assert len(ws.list_account_level_groups(filter=f"displayName sw '{env.test_uid}'")) == len( + toolkit.group_manager.migration_groups_provider.groups + ) + + for _info in toolkit.group_manager.migration_groups_provider.groups: + _ws = ws.groups.get(id=_info.workspace.id) + _backup = ws.groups.get(id=_info.backup.id) + _ws_members = sorted([m.value for m in _ws.members]) + _backup_members = sorted([m.value for m in _backup.members]) + assert _ws_members == _backup_members + + logger.debug("Verifying that the groups were created - done") + + toolkit.cleanup_inventory_table() + + with pytest.raises(AnalysisException): + toolkit.table_manager.spark.catalog.getTable(toolkit.table_manager.config.table.to_spark()) + + toolkit.inventorize_permissions() + + toolkit.apply_permissions_to_backup_groups() + + verifiable_objects = [ + (jobs, "job_id", RequestObjectType.JOBS), + ] + for _objects, id_attribute, request_object_type in verifiable_objects: + _verify_group_permissions(_objects, id_attribute, request_object_type, ws, toolkit, "backup") + + _verify_roles_and_entitlements(toolkit.group_manager.migration_groups_provider, ws, "backup") + + toolkit.replace_workspace_groups_with_account_groups() + + new_groups = list(ws.groups.list(filter=f"displayName sw '{env.test_uid}'", attributes="displayName,meta")) + assert len(new_groups) == len(toolkit.group_manager.migration_groups_provider.groups) + assert all(g.meta.resource_type == "Group" for g in new_groups) + + toolkit.apply_permissions_to_account_groups() + + for _objects, id_attribute, request_object_type in verifiable_objects: + _verify_group_permissions(_objects, id_attribute, request_object_type, ws, toolkit, "account") + + _verify_roles_and_entitlements(toolkit.group_manager.migration_groups_provider, ws, "account") + + toolkit.delete_backup_groups() + + backup_groups = list( + ws.groups.list( + filter=f"displayName sw '{config.groups.backup_group_prefix}{env.test_uid}'", attributes="displayName,meta" + ) + ) + assert len(backup_groups) == 0 + + toolkit.cleanup_inventory_table() diff --git a/dev/init_setup.py b/tests/integration/test_setup.py similarity index 60% rename from dev/init_setup.py rename to tests/integration/test_setup.py index 795ccd2d3f..8dc855f8d1 100644 --- a/dev/init_setup.py +++ b/tests/integration/test_setup.py @@ -1,9 +1,8 @@ from functools import partial -from pathlib import Path +import pytest from databricks.sdk import WorkspaceClient from databricks.sdk.service.iam import ComplexValue -from dotenv import load_dotenv from uc_migration_toolkit.providers.logger import logger from uc_migration_toolkit.utils import ThreadedExecution @@ -11,8 +10,9 @@ Threader = partial(ThreadedExecution, num_threads=40) -def _create_user(_ws: WorkspaceClient, uid: str): +def _create_user(ws: WorkspaceClient, uid: str): user_name = f"test-user-{uid}@example.com" + # TODO: listing is expensive for SCIM, better swallow the exception potential_user = list(ws.users.list(filter=f"userName eq '{user_name}'")) if potential_user: logger.debug(f"User {user_name} already exists, skipping its creation") @@ -25,20 +25,7 @@ def _create_user(_ws: WorkspaceClient, uid: str): ) -def _create_users(_ws: WorkspaceClient): +def test_create_users(ws): + pytest.skip("run only in debug") executables = [partial(_create_user, ws, uid) for uid in range(200)] Threader(executables).run() - - -if __name__ == "__main__": - principal_env = Path(__file__).parent.parent / ".env.principal" - if principal_env.exists(): - logger.info("Using credentials provided in .env.principal") - load_dotenv(dotenv_path=principal_env) - - logger.debug("setting up the workspace client") - ws = WorkspaceClient() - user_info = ws.current_user.me() - logger.debug("workspace client is set up") - - _create_users(ws) diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 04b0b94a91..06daa18adf 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -1,7 +1,6 @@ import random from dataclasses import dataclass from functools import partial -from pathlib import Path from typing import Any from databricks.sdk import AccountClient, WorkspaceClient @@ -15,7 +14,6 @@ ) from databricks.sdk.service.jobs import JobCluster, PythonWheelTask, Task from databricks.sdk.service.workspace import ObjectInfo -from dotenv import load_dotenv from uc_migration_toolkit.managers.inventory.types import RequestObjectType from uc_migration_toolkit.providers.client import ImprovedWorkspaceClient @@ -23,16 +21,6 @@ from uc_migration_toolkit.utils import WorkspaceLevelEntitlement -def initialize_env() -> None: - principal_env = Path(__file__).parent.parent.parent / ".env.principal" - - if principal_env.exists(): - logger.debug("Using credentials provided in .env.principal") - load_dotenv(dotenv_path=principal_env) - else: - logger.debug(f"No .env.principal found at {principal_env.absolute()}, using environment variables") - - @dataclass class InstanceProfile: instance_profile_arn: str diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2