Skip to content

Commit

Permalink
Fix import member issue (#3056)
Browse files Browse the repository at this point in the history
## Changes
Our current implementation does not properly collect imports when source
code uses indirect import patterns such as:
```
from importlib import import_module
module = import_module("some-module")

```
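This PR fixes that: when matching nodes, a name bound by `from <module> import <member>` is now resolved to its `<module>.<member>` attribute form.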

### Linked issues
None

### Functionality
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <eric.vergnaud@databricks.com>
  • Loading branch information
ericvergnaud and ericvergnaud authored Oct 24, 2024
1 parent f2c2f81 commit 1c8c3e6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/databricks/labs/ucx/source_code/python/python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def first_statement(self) -> NodeNG | None:
return self.tree.first_statement()


class Tree:
class Tree: # pylint: disable=too-many-public-methods

@classmethod
def maybe_parse(cls, code: str) -> MaybeTree:
Expand Down Expand Up @@ -285,6 +285,11 @@ def has_global(self, name: str) -> bool:
self_module: Module = cast(Module, self.node)
return self_module.globals.get(name, None) is not None

def get_global(self, name: str) -> list[NodeNG]:
    """Return the nodes registered under the global ``name``.

    An empty list is returned when ``has_global`` reports that ``name`` is
    not defined; otherwise the entry stored in the module's ``globals``
    mapping is returned as-is.
    """
    if self.has_global(name):
        module_node = cast(Module, self.node)
        return module_node.globals.get(name)
    return []

def nodes_between(self, first_line: int, last_line: int) -> list[NodeNG]:
if not isinstance(self.node, Module):
raise NotImplementedError(f"Can't extract nodes from {type(self.node).__name__}")
Expand Down Expand Up @@ -486,6 +491,7 @@ def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]):
self._matched_nodes: list[NodeNG] = []
self._node_type = node_type
self._match_nodes = match_nodes
self._imports: dict[str, list[NodeNG]] = {}

@property
def matched_nodes(self) -> list[NodeNG]:
Expand Down Expand Up @@ -521,6 +527,7 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
if isinstance(node, Call):
return self._matches(node.func, depth)
name, match_node = self._match_nodes[depth]
node = self._adjust_node_for_import_member(name, match_node, node)
if not isinstance(node, match_node):
return False
next_node: NodeNG | None = None
Expand All @@ -538,6 +545,39 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
return len(self._match_nodes) - 1 == depth
return self._matches(next_node, depth + 1)

def _adjust_node_for_import_member(self, name: str, match_node: type, node: NodeNG) -> NodeNG:
    """Rewrite a bare imported name into the attribute form expected by the matcher.

    When an ``Attribute`` called ``name`` is expected but ``node`` is a ``Name``
    with that same name, the globals of the node's root are searched for an
    ``ImportFrom`` bound to it. If one is found, a synthetic
    ``Attribute(attrname=name)`` over ``Name(<import_from.modname>)`` is built,
    positioned at the import statement, and returned.

    In every other case ``node`` is returned unchanged.
    """
    # nothing to adjust when the node already has the expected type
    if isinstance(node, match_node):
        return node
    # only a bare name standing in for an expected attribute is a candidate
    if match_node != Attribute:
        return node
    if not (isinstance(node, Name) and node.name == name):
        return node
    # such a name may have been bound by a 'from <module> import <name>' statement
    root_tree = Tree(Tree(node).root)
    if not root_tree.has_global(node.name):
        return node
    origin = next(
        (candidate for candidate in root_tree.get_global(node.name) if isinstance(candidate, ImportFrom)),
        None,
    )
    if origin is None:
        return node
    # both synthetic nodes borrow the position of the import statement
    module_ref = Name(
        name=origin.modname,
        parent=origin.parent,
        lineno=origin.lineno,
        col_offset=origin.col_offset,
        end_lineno=origin.end_lineno,
        end_col_offset=origin.end_col_offset,
    )
    member_ref = Attribute(
        attrname=name,
        parent=module_ref,
        lineno=origin.lineno,
        col_offset=origin.col_offset,
        end_lineno=origin.end_lineno,
        end_col_offset=origin.end_col_offset,
    )
    # NOTE(review): presumably postinit installs module_ref as the attribute's expression - confirm against astroid
    member_ref.postinit(module_ref)
    return member_ref


class NodeBase(ABC):

Expand Down
12 changes: 12 additions & 0 deletions tests/unit/source_code/python/test_python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,18 @@ def test_is_from_module() -> None:
assert Tree(save_call).is_from_module("spark")


def test_locates_member_import() -> None:
    """A call to a member imported via ``from importlib import import_module`` is located as ``importlib.import_module``."""
    source = """
from importlib import import_module
module = import_module("xyz")
"""
    parsed = Tree.maybe_normalized_parse(source)
    assert parsed.tree is not None, parsed.failure
    # innermost attribute first, then the module name it hangs off
    pattern: list[tuple[str, type]] = [("import_module", Attribute), ("importlib", Name)]
    located = parsed.tree.locate(Call, pattern)
    assert located


@pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")])
def test_is_instance_of(source, name, class_name) -> None:
maybe_tree = Tree.maybe_normalized_parse(source)
Expand Down

0 comments on commit 1c8c3e6

Please sign in to comment.