diff --git a/src/databricks/labs/ucx/providers/mixins/fixtures.py b/src/databricks/labs/ucx/providers/mixins/fixtures.py index 8685c9476a..1b8503f4d8 100644 --- a/src/databricks/labs/ucx/providers/mixins/fixtures.py +++ b/src/databricks/labs/ucx/providers/mixins/fixtures.py @@ -11,6 +11,7 @@ from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.core import DatabricksError from databricks.sdk.service import compute, iam, jobs, pipelines, workspace +from databricks.sdk.service.sql import CreateWarehouseRequestWarehouseType _LOG = logging.getLogger(__name__) @@ -473,6 +474,36 @@ def create(**kwargs) -> pipelines.CreatePipelineResponse: yield from factory("delta live table", create, lambda item: ws.pipelines.delete(item.pipeline_id)) +@pytest.fixture +def make_warehouse(ws, make_random): + def create( + *, + warehouse_name: str | None = None, + warehouse_type: CreateWarehouseRequestWarehouseType | None = None, + cluster_size: str | None = None, + max_num_clusters: int = 1, + enable_serverless_compute: bool = False, + **kwargs, + ): + if warehouse_name is None: + warehouse_name = f"sdk-{make_random(4)}" + if warehouse_type is None: + warehouse_type = CreateWarehouseRequestWarehouseType.PRO + if cluster_size is None: + cluster_size = "2X-Small" + + return ws.warehouses.create( + name=warehouse_name, + cluster_size=cluster_size, + warehouse_type=warehouse_type, + max_num_clusters=max_num_clusters, + enable_serverless_compute=enable_serverless_compute, + **kwargs, + ) + + yield from factory("warehouse", create, lambda item: ws.warehouses.delete(item.id)) + + def load_debug_env_if_runs_from_ide(key) -> bool: if not _is_in_debug(): return False diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index e0481460fe..76e62f0ad9 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -13,10 +13,6 @@ from databricks.sdk.service.iam import AccessControlRequest, PermissionLevel from databricks.sdk.service.ml import CreateExperimentResponse, ModelDatabricks from databricks.sdk.service.ml import PermissionLevel as ModelPermissionLevel -from databricks.sdk.service.sql import ( - CreateWarehouseRequestWarehouseType, - GetWarehouseResponse, -) from databricks.sdk.service.workspace import ObjectInfo, ObjectType from databricks.labs.ucx.config import InventoryTable @@ -42,7 +38,6 @@ NUM_TEST_CLUSTER_POLICIES = int(os.environ.get("NUM_TEST_CLUSTER_POLICIES", 3)) NUM_TEST_EXPERIMENTS = int(os.environ.get("NUM_TEST_EXPERIMENTS", 3)) NUM_TEST_MODELS = int(os.environ.get("NUM_TEST_MODELS", 3)) -NUM_TEST_WAREHOUSES = int(os.environ.get("NUM_TEST_WAREHOUSES", 3)) NUM_TEST_TOKENS = int(os.environ.get("NUM_TEST_TOKENS", 3)) NUM_THREADS = int(os.environ.get("NUM_TEST_THREADS", 20)) @@ -352,41 +347,6 @@ def models(ws: WorkspaceClient, env: EnvironmentInfo) -> list[ModelDatabricks]: logger.debug("Test models deleted") -@pytest.fixture -def warehouses(ws: WorkspaceClient, env: EnvironmentInfo) -> list[GetWarehouseResponse]: - logger.debug("Creating warehouses") - - creators = [ - partial( - ws.warehouses.create, - name=f"{env.test_uid}-test-{i}", - cluster_size="2X-Small", - warehouse_type=CreateWarehouseRequestWarehouseType.PRO, - max_num_clusters=1, - enable_serverless_compute=False, - ) - for i in range(NUM_TEST_WAREHOUSES) - ] - - test_warehouses: list[GetWarehouseResponse] = Threader(creators).run() - - _set_random_permissions( - test_warehouses, - "id", - RequestObjectType.SQL_WAREHOUSES, - env, - ws, - permission_levels=[PermissionLevel.CAN_USE, PermissionLevel.CAN_MANAGE], - ) - - yield test_warehouses - - logger.debug("Deleting test warehouses") - executables = [partial(ws.warehouses.delete, w.id) for w in test_warehouses] - Threader(executables).run() - logger.debug("Test warehouses deleted") - - @pytest.fixture def tokens(ws: WorkspaceClient, env: EnvironmentInfo) -> list[AccessControlRequest]: logger.debug("Adding token-level permissions to groups") @@ -465,7 +425,6 @@ def verifiable_objects( cluster_policies, experiments, models, - warehouses, tokens, workspace_objects, ) -> list[tuple[list, str, RequestObjectType | None]]: @@ -475,7 +434,6 @@ def verifiable_objects( (cluster_policies, "policy_id", RequestObjectType.CLUSTER_POLICIES), (experiments, "experiment_id", RequestObjectType.EXPERIMENTS), (models, "id", RequestObjectType.REGISTERED_MODELS), - (warehouses, "id", RequestObjectType.SQL_WAREHOUSES), ] yield _verifiable_objects diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index ef32640f54..5fe6a73561 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -140,6 +140,8 @@ def test_e2e( make_pipeline_permissions, make_secret_scope, make_secret_scope_acl, + make_warehouse, + make_warehouse_permissions, ): logger.debug(f"Test environment: {env.test_uid}") ws_group = env.groups[0][0] @@ -190,6 +192,16 @@ def test_e2e( make_secret_scope_acl(scope=scope, principal=ws_group.display_name, permission=workspace.AclPermission.WRITE) verifiable_objects.append(([scope], "secret_scopes", None)) + warehouse = make_warehouse() + make_warehouse_permissions( + object_id=warehouse.id, + permission_level=random.choice([PermissionLevel.CAN_USE, PermissionLevel.CAN_MANAGE]), + group_name=ws_group.display_name, + ) + verifiable_objects.append( + ([warehouse], "id", RequestObjectType.SQL_WAREHOUSES), + ) + config = MigrationConfig( connect=ConnectConfig.from_databricks_config(ws.config), inventory=InventoryConfig(table=inventory_table),