From 663b91ae6ebc4048bd7006f6766bff30bec2dee7 Mon Sep 17 00:00:00 2001
From: Eric Vergnaud
Date: Thu, 24 Oct 2024 13:03:47 +0200
Subject: [PATCH 1/2] add test

---
 tests/unit/source_code/python/test_python_ast.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py
index c23b76c4ef..e65abe4b8e 100644
--- a/tests/unit/source_code/python/test_python_ast.py
+++ b/tests/unit/source_code/python/test_python_ast.py
@@ -170,6 +170,18 @@ def test_is_from_module() -> None:
     assert Tree(save_call).is_from_module("spark")
 
 
+def test_locates_member_import() -> None:
+    source = """
+from importlib import import_module
+module = import_module("xyz")
+"""
+    maybe_tree = Tree.maybe_normalized_parse(source)
+    assert maybe_tree.tree is not None, maybe_tree.failure
+    tree = maybe_tree.tree
+    import_calls = tree.locate(Call, [("import_module", Attribute), ("importlib", Name)])
+    assert import_calls
+
+
 @pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")])
 def test_is_instance_of(source, name, class_name) -> None:
     maybe_tree = Tree.maybe_normalized_parse(source)

From 8456ac393116b2e6a8ec01cfc379f8ac92c5dfb4 Mon Sep 17 00:00:00 2001
From: Eric Vergnaud
Date: Thu, 24 Oct 2024 13:05:05 +0200
Subject: [PATCH 2/2] resolve global name to import member

---
Reviewer revision of the original patch: documents the new helpers, drops the
duplicate has_global() lookup, and skips "from m import other as name" aliases
so that an alias is not reported as "m.name". The index hashes below predate
this revision.

 .../labs/ucx/source_code/python/python_ast.py | 49 ++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py
index be860d8fed..72c315e768 100644
--- a/src/databricks/labs/ucx/source_code/python/python_ast.py
+++ b/src/databricks/labs/ucx/source_code/python/python_ast.py
@@ -68,7 +68,7 @@ def first_statement(self) -> NodeNG | None:
         return self.tree.first_statement()
 
 
-class Tree:
+class Tree:  # pylint: disable=too-many-public-methods
 
     @classmethod
     def maybe_parse(cls, code: str) -> MaybeTree:
@@ -285,6 +285,12 @@ def has_global(self, name: str) -> bool:
         self_module: Module = cast(Module, self.node)
         return self_module.globals.get(name, None) is not None
 
+    def get_global(self, name: str) -> list[NodeNG]:
+        """Return the nodes bound to the global `name`, or an empty list when there is none."""
+        if not self.has_global(name):
+            return []
+        return cast(Module, self.node).globals.get(name)
+
     def nodes_between(self, first_line: int, last_line: int) -> list[NodeNG]:
         if not isinstance(self.node, Module):
             raise NotImplementedError(f"Can't extract nodes from {type(self.node).__name__}")
@@ -486,6 +492,7 @@ def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]):
         self._matched_nodes: list[NodeNG] = []
         self._node_type = node_type
         self._match_nodes = match_nodes
+        self._imports: dict[str, list[NodeNG]] = {}
 
     @property
     def matched_nodes(self) -> list[NodeNG]:
@@ -521,6 +528,7 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
         if isinstance(node, Call):
             return self._matches(node.func, depth)
         name, match_node = self._match_nodes[depth]
+        node = self._adjust_node_for_import_member(name, match_node, node)
         if not isinstance(node, match_node):
             return False
         next_node: NodeNG | None = None
@@ -538,6 +546,45 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
             return len(self._match_nodes) - 1 == depth
         return self._matches(next_node, depth + 1)
 
+    def _adjust_node_for_import_member(self, name: str, match_node: type, node: NodeNG) -> NodeNG:
+        """Rewrite a bare name bound by `from module import name` as the attribute `module.name`.
+
+        This lets a matcher expecting an `Attribute` also recognise members that were imported
+        directly into the module namespace. Nodes that do not qualify are returned unchanged.
+        """
+        if isinstance(node, match_node):
+            return node
+        # if we're looking for an attribute, it might be a global name
+        if match_node is not Attribute or not isinstance(node, Name) or node.name != name:
+            return node
+        # in which case it could be an import member (get_global yields nothing otherwise)
+        module = Tree(Tree(node).root)
+        for import_from in module.get_global(name):
+            if not isinstance(import_from, ImportFrom):
+                continue
+            # `from module import other as name` binds a different member under that name
+            if any(alias == name and member != name for member, alias in import_from.names):
+                continue
+            parent = Name(
+                name=import_from.modname,
+                lineno=import_from.lineno,
+                col_offset=import_from.col_offset,
+                end_lineno=import_from.end_lineno,
+                end_col_offset=import_from.end_col_offset,
+                parent=import_from.parent,
+            )
+            resolved = Attribute(
+                attrname=name,
+                lineno=import_from.lineno,
+                col_offset=import_from.col_offset,
+                end_lineno=import_from.end_lineno,
+                end_col_offset=import_from.end_col_offset,
+                parent=parent,
+            )
+            resolved.postinit(parent)
+            return resolved
+        return node
+
 
 class NodeBase(ABC):
 