Crawler for RunSubmit API usages from External Orchestrators (ADF/Airflow) #364

Closed
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
]
dependencies = [
"databricks-sdk~=0.9.0",
"databricks-sdk~=0.10.0",
"PyYAML>=6.0.0,<7.0.0",
]

87 changes: 77 additions & 10 deletions src/databricks/labs/ucx/assessment/crawlers.py
@@ -1,13 +1,21 @@
import datetime
import json
import logging
import re
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.core import DatabricksError
from databricks.sdk.service.compute import ClusterDetails, ClusterSource
from databricks.sdk.service.jobs import BaseJob

from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)

INCOMPATIBLE_SPARK_CONFIG_KEYS = [
"spark.databricks.passthrough.enabled",
"spark.hadoop.javax.jdo.option.ConnectionURL",
@@ -53,6 +61,15 @@ class PipelineInfo:
failures: str


@dataclass
class JobRunInfo:
run_id: int
run_type: str
cluster_key: str | None
spark_version: str
data_security_mode: str


def _azure_sp_conf_present_check(config: dict) -> bool:
for key in config.keys():
for conf in _AZURE_SP_CONF:
@@ -145,12 +162,15 @@ def _assess_clusters(self, all_clusters):

# Checking if Azure cluster config is present in cluster policies
if cluster.policy_id:
try:
policy = self._ws.cluster_policies.get(cluster.policy_id)
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
except DatabricksError as err:
logger.warning(f"Error retrieving cluster policy {cluster.policy_id}. Error: {err}")

cluster_info.failures = json.dumps(failures)
if len(failures) > 0:
@@ -220,12 +240,15 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> list[JobInfo]:

# Checking if Azure cluster config is present in cluster policies
if cluster_config.policy_id:
try:
policy = self._ws.cluster_policies.get(cluster_config.policy_id)
if _azure_sp_conf_present_check(json.loads(policy.definition)):
job_assessment[job.job_id].add(f"{_AZURE_SP_CONF_FAILURE_MSG} Job cluster.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
job_assessment[job.job_id].add(f"{_AZURE_SP_CONF_FAILURE_MSG} Job cluster.")
except DatabricksError as err:
logger.warning(f"Error retrieving cluster policy {cluster_config.policy_id}. Error: {err}")

for job_key in job_details.keys():
job_details[job_key].failures = json.dumps(list(job_assessment[job_key]))
@@ -239,3 +262,47 @@ def snapshot(self) -> list[ClusterInfo]:
def _try_fetch(self) -> list[ClusterInfo]:
for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
yield JobInfo(*row)


class JobsRunCrawler(CrawlerBase):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "job_runs")
self._ws = ws

def _crawl(self) -> list[JobRunInfo]:
no_of_days_back = datetime.timedelta(days=30) # todo make configurable in yaml?
start_time_from = datetime.datetime.now() - no_of_days_back
# todo figure out if we need to specify a default timezone
all_job_runs = list(
self._ws.jobs.list_runs(start_time_from=start_time_from, start_time_to=datetime.datetime.now())
)
all_clusters: dict[str, ClusterDetails] = {c.cluster_id: c for c in self._ws.clusters.list()}
return self._assess_job_runs(all_clusters, all_job_runs)

def _assess_job_runs(self, all_clusters, all_job_runs):
all_job_run_info = []
# Runs without a job_id were submitted through the RunSubmit API,
# typically by external orchestrators such as ADF or Airflow.
job_runs_without_job_id = list(filter(lambda jr: jr.job_id is None, all_job_runs))
for job_run in job_runs_without_job_id:
for job_run_cluster in job_run.job_clusters:
cluster_key = job_run_cluster.job_cluster_key
# Resolve a named cluster from the workspace listing; fall back to
# the ephemeral cluster spec submitted with the run.
cluster: ClusterDetails = (
all_clusters[cluster_key] if cluster_key is not None else job_run_cluster.new_cluster
)
spark_version = cluster.spark_version
data_security_mode = cluster.data_security_mode
job_run_info = JobRunInfo(
run_id=job_run.run_id,
run_type=str(job_run.run_type.value),
cluster_key=cluster_key,
spark_version=spark_version,
data_security_mode=str(data_security_mode.value),
)
all_job_run_info.append(job_run_info)
return all_job_run_info

def snapshot(self) -> list[JobRunInfo]:
return self._snapshot(self._try_fetch, self._crawl)

def _try_fetch(self) -> list[JobRunInfo]:
for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
yield JobRunInfo(*row)
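
For orientation, here is a minimal sketch of how the new crawler might be wired up end to end. It is not part of the diff; the `StatementExecutionBackend` wiring and the warehouse id are assumptions standing in for whatever SQL backend a deployment actually uses:

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.crawlers import JobsRunCrawler
from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend

# Hypothetical wiring: credentials come from the environment, and crawled
# rows land in hive_metastore.ucx.job_runs via the SQL backend.
ws = WorkspaceClient()
sbe = StatementExecutionBackend(ws, "my-warehouse-id")  # placeholder warehouse id
crawler = JobsRunCrawler(ws, sbe, "ucx")

for run in crawler.snapshot():
    print(run.run_id, run.run_type, run.spark_version, run.data_security_mode)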
9 changes: 6 additions & 3 deletions src/databricks/labs/ucx/runtime.py
@@ -190,9 +190,12 @@ def migrate_permissions(cfg: WorkspaceConfig):
See [interactive tutorial here](https://app.getreprise.com/launch/myM3VNn/)."""
toolkit = GroupMigrationToolkit(cfg)
toolkit.prepare_environment()
if toolkit.has_groups():
toolkit.apply_permissions_to_backup_groups()
toolkit.replace_workspace_groups_with_account_groups()
toolkit.apply_permissions_to_account_groups()
else:
logger.info("Skipping group migration as no groups were found.")


@task("migrate-groups-cleanup", depends_on=[migrate_permissions])
3 changes: 3 additions & 0 deletions src/databricks/labs/ucx/workspace_access/groups.py
@@ -181,6 +181,9 @@ def prepare_groups_in_environment(self):
self._set_migration_groups(group_names)
logger.info("Environment prepared successfully")

def has_groups(self) -> bool:
return len(self._migration_state.groups) > 0

@property
def migration_groups_provider(self) -> GroupMigrationState:
assert len(self._migration_state.groups) > 0, "Migration groups were not loaded or initialized"
3 changes: 3 additions & 0 deletions src/databricks/labs/ucx/workspace_access/migration.py
@@ -114,6 +114,9 @@ def _configure_logger(level: str):
ucx_logger = logging.getLogger("databricks.labs.ucx")
ucx_logger.setLevel(level)

def has_groups(self) -> bool:
return self._group_manager.has_groups()

def prepare_environment(self):
self._group_manager.prepare_groups_in_environment()

130 changes: 128 additions & 2 deletions tests/unit/assessment/test_assessment.py
@@ -1,12 +1,47 @@
from unittest.mock import Mock

from databricks.sdk.core import DatabricksError
from databricks.sdk.service.compute import (
AutoScale,
ClusterDetails,
ClusterSource,
DataSecurityMode,
)
from databricks.sdk.service.jobs import (
BaseJob,
BaseRun,
JobCluster,
JobSettings,
NotebookTask,
RunType,
Task,
)
from databricks.sdk.service.pipelines import PipelineState, PipelineStateInfo

from databricks.labs.ucx.assessment.crawlers import (
ClustersCrawler,
JobRunInfo,
JobsCrawler,
JobsRunCrawler,
PipelineInfo,
PipelinesCrawler,
)
@@ -395,6 +430,24 @@ def test_cluster_assessment_cluster_policy_no_spark_conf(mocker):
assert result_set1[0].success == 1


def test_cluster_assessment_cluster_policy_not_found(mocker):
sample_clusters1 = [
ClusterDetails(
cluster_name="cluster1",
autoscale=AutoScale(min_workers=1, max_workers=6),
spark_context_id=5134472582179565315,
spark_env_vars=None,
policy_id="D96308F1BF0003A8",
spark_version="13.3.x-cpu-ml-scala2.12",
cluster_id="0915-190044-3dqy6751",
)
]
ws = Mock()
ws.cluster_policies.get.side_effect = DatabricksError(error="NO_POLICY", error_code="NO_POLICY")
# The missing policy should be logged and skipped rather than raised.
results = ClustersCrawler(ws, MockBackend(), "ucx")._assess_clusters(sample_clusters1)
list(results)


def test_pipeline_assessment_with_config(mocker):
sample_pipelines = [
PipelineStateInfo(
@@ -467,3 +520,76 @@ def test_pipeline_snapshot_with_config():

assert len(result_set) == 1
assert result_set[0].success == 1


def test_job_run_crawler():
"""
Validate that JobsRunCrawler
- returns a list of JobRunInfo objects
- of the expected size
- with the expected values
"""
sample_job_run_infos = [
JobRunInfo(
run_id=123456789,
run_type=RunType.SUBMIT_RUN.value,
cluster_key=None,
spark_version="11.3.x-scala2.12",
data_security_mode=DataSecurityMode.NONE.value,
),
JobRunInfo(
run_id=123456790,
run_type=RunType.WORKFLOW_RUN.value,
cluster_key=None,
spark_version="11.3.x-scala2.12",
data_security_mode=DataSecurityMode.SINGLE_USER.value,
),
]
mock_ws = Mock()

crawler = JobsRunCrawler(mock_ws, MockBackend(), "ucx")

crawler._try_fetch = Mock(return_value=[])
crawler._crawl = Mock(return_value=sample_job_run_infos)

result_set = crawler.snapshot()

assert len(result_set) == 2
assert result_set[0].data_security_mode == DataSecurityMode.NONE.value


def test_job_run_crawler_filters_runs_with_job_id():
"""
Validate that job runs with a job id are not included in the result set.
"""
sample_job_runs = [
BaseRun(
job_id=12345678910,
run_id=123456789,
run_type=RunType.SUBMIT_RUN,
job_clusters=[JobCluster(job_cluster_key="my_job_cluster")],
)
]
sample_clusters = [
ClusterDetails(
autoscale=AutoScale(min_workers=1, max_workers=6),
spark_context_id=5134472582179565315,
spark_env_vars=None,
spark_version="11.3.x-scala2.12",
cluster_id="my_job_cluster",
cluster_source=ClusterSource.JOB,
)
]
mock_ws = Mock()

mock_ws.jobs.list_runs = Mock(return_value=sample_job_runs)
mock_ws.clusters.list = Mock(return_value=sample_clusters)

crawler = JobsRunCrawler(mock_ws, MockBackend(), "ucx")

crawler._try_fetch = Mock(return_value=[])

result_set = crawler.snapshot()

assert len(result_set) == 0
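
A natural complement, sketched here rather than taken from the PR, would cover the positive path: a run submitted with no job_id should surface in the snapshot. The fixture names and values below mirror the test above.

def test_job_run_crawler_includes_submit_runs():
    """Hypothetical companion test (not in this PR): a run without a job_id,
    i.e. one submitted via the RunSubmit API, should be assessed."""
    sample_job_runs = [
        BaseRun(
            job_id=None,
            run_id=123456789,
            run_type=RunType.SUBMIT_RUN,
            job_clusters=[JobCluster(job_cluster_key="my_job_cluster")],
        )
    ]
    sample_clusters = [
        ClusterDetails(
            spark_version="11.3.x-scala2.12",
            cluster_id="my_job_cluster",
            data_security_mode=DataSecurityMode.NONE,
            cluster_source=ClusterSource.JOB,
        )
    ]
    mock_ws = Mock()
    mock_ws.jobs.list_runs = Mock(return_value=sample_job_runs)
    mock_ws.clusters.list = Mock(return_value=sample_clusters)

    crawler = JobsRunCrawler(mock_ws, MockBackend(), "ucx")
    crawler._try_fetch = Mock(return_value=[])

    result_set = crawler.snapshot()

    assert len(result_set) == 1
    assert result_set[0].run_type == str(RunType.SUBMIT_RUN.value)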
13 changes: 13 additions & 0 deletions tests/unit/workspace_access/test_groups.py
@@ -286,6 +286,19 @@ def test_prepare_groups_in_environment_with_conf_in_auto_mode_should_populate_mi
assert manager._migration_state.groups == [group_info]


def test_prepare_groups_in_environment_with_no_groups():
client = Mock()
client.groups.list.return_value = iter([])
client.api_client.do.return_value = {
"Resources": [],
}

group_conf = GroupsConfig(auto=True)
manager = GroupManager(client, group_conf)
manager.prepare_groups_in_environment()
assert not manager.has_groups()


def test_replace_workspace_groups_with_account_groups_should_call_delete_and_do():
client = Mock()
