From 4707bbcfb6fbf5cf44b6987e62e46122f57cad9f Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Mon, 21 Oct 2024 13:18:37 +0200 Subject: [PATCH] support all cell languages --- src/databricks/labs/ucx/source_code/jobs.py | 39 ++++++++++++++----- .../unit/source_code/linters/test_directfs.py | 24 +++++++++++- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/databricks/labs/ucx/source_code/jobs.py b/src/databricks/labs/ucx/source_code/jobs.py index 5565f1ef11..60594588bb 100644 --- a/src/databricks/labs/ucx/source_code/jobs.py +++ b/src/databricks/labs/ucx/source_code/jobs.py @@ -4,7 +4,7 @@ import shutil import tempfile from abc import ABC, abstractmethod -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterable, Callable from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timezone @@ -651,14 +651,15 @@ def _collect_from_source( path: Path, inherited_tree: Tree | None, ) -> Iterable[T]: - iterable: Iterable[T] | None = None - if language is CellLanguage.SQL: - iterable = self._collect_from_sql(source) if language is CellLanguage.PYTHON: iterable = self._collect_from_python(source, inherited_tree) - if iterable is None: - logger.warning(f"Language {language.name} not supported yet!") - return + else: + fn: Callable[[str], Iterable[T]] | None = getattr(self, f"_collect_from_{language.name.lower()}", None) + if not fn: + raise ValueError(f"Language {language.name} not supported yet!") + # the below is for disabling a false pylint positive + # pylint: disable=not-callable + iterable = fn(source) src_timestamp = datetime.fromtimestamp(path.stat().st_mtime, timezone.utc) src_id = str(path) for item in iterable: @@ -667,8 +668,28 @@ def _collect_from_source( @abstractmethod def _collect_from_python(self, source: str, inherited_tree: Tree | None) -> Iterable[T]: ... - @abstractmethod - def _collect_from_sql(self, source: str) -> Iterable[T]: ... + def _collect_from_sql(self, _source: str) -> Iterable[T]: + return [] + + def _collect_from_r(self, _source: str) -> Iterable[T]: + logger.warning("Language R not supported yet!") + return [] + + def _collect_from_scala(self, _source: str) -> Iterable[T]: + logger.warning("Language scala not supported yet!") + return [] + + def _collect_from_shell(self, _source: str) -> Iterable[T]: + return [] + + def _collect_from_markdown(self, _source: str) -> Iterable[T]: + return [] + + def _collect_from_run(self, _source: str) -> Iterable[T]: + return [] + + def _collect_from_pip(self, _source: str) -> Iterable[T]: + return [] class DfsaCollectorWalker(_CollectorWalker[DirectFsAccess]): diff --git a/tests/unit/source_code/linters/test_directfs.py b/tests/unit/source_code/linters/test_directfs.py index 02790ad346..4a3c0ba483 100644 --- a/tests/unit/source_code/linters/test_directfs.py +++ b/tests/unit/source_code/linters/test_directfs.py @@ -1,11 +1,18 @@ +from collections.abc import Iterable +from pathlib import Path +from unittest.mock import create_autospec + import pytest -from databricks.labs.ucx.source_code.base import Deprecation, Advice, CurrentSessionState, Failure +from databricks.labs.ucx.source_code.base import Deprecation, Advice, CurrentSessionState, Failure, DirectFsAccess +from databricks.labs.ucx.source_code.graph import DependencyGraph +from databricks.labs.ucx.source_code.jobs import DfsaCollectorWalker from databricks.labs.ucx.source_code.linters.directfs import ( DIRECT_FS_ACCESS_PATTERNS, DirectFsAccessPyLinter, DirectFsAccessSqlLinter, ) +from databricks.labs.ucx.source_code.notebooks.cells import CellLanguage @pytest.mark.parametrize( @@ -145,3 +152,18 @@ def test_dfsa_queries_failure(query: str) -> None: end_col=1024, ), ] + + +class _TestCollectorWalker(DfsaCollectorWalker): + # inherit from DfsaCollectorWalker because it's public + + def collect_from_source(self, language: CellLanguage) -> Iterable[DirectFsAccess]: + return self._collect_from_source("empty", language, Path(""), None) + + +@pytest.mark.parametrize("language", list(iter(CellLanguage))) +def test_collector_supports_all_cell_languages(language, mock_path_lookup, migration_index): + graph = create_autospec(DependencyGraph) + graph.assert_not_called() + collector = _TestCollectorWalker(graph, set(), mock_path_lookup, CurrentSessionState(), migration_index) + list(collector.collect_from_source(language))