From e8179a1ba2ebd581cbd9c9f054b636549d97d080 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Thu, 24 Oct 2024 18:37:11 +0200 Subject: [PATCH] Fix dynamic import issue (#3053) ## Changes Our current implementation doesn't infer import names when calling `importlib.import_module(some_name)`. This PR fixes that. ### Linked issues None ### Functionality None ### Tests - [x] added unit tests --------- Co-authored-by: Eric Vergnaud --- .../labs/ucx/source_code/linters/imports.py | 18 ++++++-- .../labs/ucx/source_code/python/python_ast.py | 42 ++++++++++++++++++- tests/integration/source_code/test_graph.py | 19 +++++++++ .../source_code/python/test_python_ast.py | 12 ++++++ .../unit/source_code/samples/import-module.py | 4 ++ 5 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 tests/unit/source_code/samples/import-module.py diff --git a/src/databricks/labs/ucx/source_code/linters/imports.py b/src/databricks/labs/ucx/source_code/linters/imports.py index 906a24cd19..d44b61b70a 100644 --- a/src/databricks/labs/ucx/source_code/linters/imports.py +++ b/src/databricks/labs/ucx/source_code/linters/imports.py @@ -9,7 +9,6 @@ from astroid import ( # type: ignore Attribute, Call, - Const, InferenceError, Import, ImportFrom, @@ -71,9 +70,20 @@ def _make_sources_for_import_call_nodes( problems: list[T], ) -> Iterable[ImportSource]: for node in nodes: - arg = node.args[0] - if isinstance(arg, Const): - yield ImportSource(node, arg.value) + yield from cls._make_sources_for_import_call_node(node, problem_factory, problems) + + @classmethod + def _make_sources_for_import_call_node( + cls, + node: Call, + problem_factory: ProblemFactory, + problems: list[T], + ) -> Iterable[ImportSource]: + if not node.args: + return + for inferred in InferredValue.infer_from_node(node.args[0]): + if inferred.is_inferred(): + yield ImportSource(node, inferred.as_string()) continue problem = problem_factory( 'dependency-not-constant', "Can't check dependency not provided as a constant", 
node diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py index be860d8fed..72c315e768 100644 --- a/src/databricks/labs/ucx/source_code/python/python_ast.py +++ b/src/databricks/labs/ucx/source_code/python/python_ast.py @@ -68,7 +68,7 @@ def first_statement(self) -> NodeNG | None: return self.tree.first_statement() -class Tree: +class Tree: # pylint: disable=too-many-public-methods @classmethod def maybe_parse(cls, code: str) -> MaybeTree: @@ -285,6 +285,11 @@ def has_global(self, name: str) -> bool: self_module: Module = cast(Module, self.node) return self_module.globals.get(name, None) is not None + def get_global(self, name: str) -> list[NodeNG]: + if not self.has_global(name): + return [] + return cast(Module, self.node).globals.get(name) + def nodes_between(self, first_line: int, last_line: int) -> list[NodeNG]: if not isinstance(self.node, Module): raise NotImplementedError(f"Can't extract nodes from {type(self.node).__name__}") @@ -486,6 +491,7 @@ def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]): self._matched_nodes: list[NodeNG] = [] self._node_type = node_type self._match_nodes = match_nodes + self._imports: dict[str, list[NodeNG]] = {} @property def matched_nodes(self) -> list[NodeNG]: @@ -521,6 +527,7 @@ def _matches(self, node: NodeNG, depth: int) -> bool: if isinstance(node, Call): return self._matches(node.func, depth) name, match_node = self._match_nodes[depth] + node = self._adjust_node_for_import_member(name, match_node, node) if not isinstance(node, match_node): return False next_node: NodeNG | None = None @@ -538,6 +545,39 @@ def _matches(self, node: NodeNG, depth: int) -> bool: return len(self._match_nodes) - 1 == depth return self._matches(next_node, depth + 1) + def _adjust_node_for_import_member(self, name: str, match_node: type, node: NodeNG) -> NodeNG: + if isinstance(node, match_node): + return node + # if we're looking for an attribute, it 
might be a global name + if match_node != Attribute or not isinstance(node, Name) or node.name != name: + return node + # in which case it could be an import member + module = Tree(Tree(node).root) + if not module.has_global(node.name): + return node + for import_from in module.get_global(node.name): + if not isinstance(import_from, ImportFrom): + continue + parent = Name( + name=import_from.modname, + lineno=import_from.lineno, + col_offset=import_from.col_offset, + end_lineno=import_from.end_lineno, + end_col_offset=import_from.end_col_offset, + parent=import_from.parent, + ) + resolved = Attribute( + attrname=name, + lineno=import_from.lineno, + col_offset=import_from.col_offset, + end_lineno=import_from.end_lineno, + end_col_offset=import_from.end_col_offset, + parent=parent, + ) + resolved.postinit(parent) + return resolved + return node + class NodeBase(ABC): diff --git a/tests/integration/source_code/test_graph.py b/tests/integration/source_code/test_graph.py index 8fb279a286..b66e14a740 100644 --- a/tests/integration/source_code/test_graph.py +++ b/tests/integration/source_code/test_graph.py @@ -38,3 +38,22 @@ def module_compatibility(self, name: str) -> Compatibility: # visit the graph without a 'visited' set roots = graph.root_dependencies assert roots + + +def test_graph_imports_dynamic_import(): + allow_list = KnownList() + library_resolver = PythonLibraryResolver(allow_list) + notebook_resolver = NotebookResolver(NotebookLoader()) + import_resolver = ImportFileResolver(FileLoader(), allow_list) + path_lookup = PathLookup.from_sys_path(Path(__file__).parent) + dependency_resolver = DependencyResolver( + library_resolver, notebook_resolver, import_resolver, import_resolver, path_lookup + ) + root_path = Path(__file__).parent.parent.parent / "unit" / "source_code" / "samples" / "import-module.py" + assert root_path.is_file() + maybe = dependency_resolver.resolve_file(path_lookup, root_path) + assert maybe.dependency + graph = 
DependencyGraph(maybe.dependency, None, dependency_resolver, path_lookup, CurrentSessionState()) + container = maybe.dependency.load(path_lookup) + problems = container.build_dependency_graph(graph) + assert not problems diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py index c23b76c4ef..e65abe4b8e 100644 --- a/tests/unit/source_code/python/test_python_ast.py +++ b/tests/unit/source_code/python/test_python_ast.py @@ -170,6 +170,18 @@ def test_is_from_module() -> None: assert Tree(save_call).is_from_module("spark") +def test_locates_member_import() -> None: + source = """ +from importlib import import_module +module = import_module("xyz") +""" + maybe_tree = Tree.maybe_normalized_parse(source) + assert maybe_tree.tree is not None, maybe_tree.failure + tree = maybe_tree.tree + import_calls = tree.locate(Call, [("import_module", Attribute), ("importlib", Name)]) + assert import_calls + + @pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")]) def test_is_instance_of(source, name, class_name) -> None: maybe_tree = Tree.maybe_normalized_parse(source) diff --git a/tests/unit/source_code/samples/import-module.py b/tests/unit/source_code/samples/import-module.py new file mode 100644 index 0000000000..6e05a636ce --- /dev/null +++ b/tests/unit/source_code/samples/import-module.py @@ -0,0 +1,4 @@ +import importlib + +module_name = "astroid" +module = importlib.import_module(module_name)