Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix import member issue #3056

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion src/databricks/labs/ucx/source_code/python/python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def first_statement(self) -> NodeNG | None:
return self.tree.first_statement()


class Tree:
class Tree: # pylint: disable=too-many-public-methods

@classmethod
def maybe_parse(cls, code: str) -> MaybeTree:
Expand Down Expand Up @@ -285,6 +285,11 @@ def has_global(self, name: str) -> bool:
self_module: Module = cast(Module, self.node)
return self_module.globals.get(name, None) is not None

def get_global(self, name: str) -> list[NodeNG]:
if not self.has_global(name):
return []
return cast(Module, self.node).globals.get(name)

def nodes_between(self, first_line: int, last_line: int) -> list[NodeNG]:
if not isinstance(self.node, Module):
raise NotImplementedError(f"Can't extract nodes from {type(self.node).__name__}")
Expand Down Expand Up @@ -486,6 +491,7 @@ def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]):
self._matched_nodes: list[NodeNG] = []
self._node_type = node_type
self._match_nodes = match_nodes
self._imports: dict[str, list[NodeNG]] = {}

@property
def matched_nodes(self) -> list[NodeNG]:
Expand Down Expand Up @@ -521,6 +527,7 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
if isinstance(node, Call):
return self._matches(node.func, depth)
name, match_node = self._match_nodes[depth]
node = self._adjust_node_for_import_member(name, match_node, node)
if not isinstance(node, match_node):
return False
next_node: NodeNG | None = None
Expand All @@ -538,6 +545,39 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
return len(self._match_nodes) - 1 == depth
return self._matches(next_node, depth + 1)

def _adjust_node_for_import_member(self, name: str, match_node: type, node: NodeNG) -> NodeNG:
if isinstance(node, match_node):
return node
# if we're looking for an attribute, it might be a global name
if match_node != Attribute or not isinstance(node, Name) or node.name != name:
return node
# in which case it could be an import member
module = Tree(Tree(node).root)
if not module.has_global(node.name):
return node
for import_from in module.get_global(node.name):
if not isinstance(import_from, ImportFrom):
continue
parent = Name(
name=import_from.modname,
lineno=import_from.lineno,
col_offset=import_from.col_offset,
end_lineno=import_from.end_lineno,
end_col_offset=import_from.end_col_offset,
parent=import_from.parent,
)
resolved = Attribute(
attrname=name,
lineno=import_from.lineno,
col_offset=import_from.col_offset,
end_lineno=import_from.end_lineno,
end_col_offset=import_from.end_col_offset,
parent=parent,
)
resolved.postinit(parent)
return resolved
return node


class NodeBase(ABC):

Expand Down
12 changes: 12 additions & 0 deletions tests/unit/source_code/python/test_python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,18 @@ def test_is_from_module() -> None:
assert Tree(save_call).is_from_module("spark")


def test_locates_member_import() -> None:
source = """
from importlib import import_module
module = import_module("xyz")
"""
maybe_tree = Tree.maybe_normalized_parse(source)
assert maybe_tree.tree is not None, maybe_tree.failure
tree = maybe_tree.tree
import_calls = tree.locate(Call, [("import_module", Attribute), ("importlib", Name)])
assert import_calls


@pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")])
def test_is_instance_of(source, name, class_name) -> None:
maybe_tree = Tree.maybe_normalized_parse(source)
Expand Down
Loading