Skip to content

Commit

Permalink
Fix dynamic import issue (#3053)
Browse files Browse the repository at this point in the history
## Changes
Our current implementation doesn't infer the imported module name when the
argument of `importlib.import_module(some_name)` is not a literal constant.
This PR fixes that.

### Linked issues
None

### Functionality
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <eric.vergnaud@databricks.com>
  • Loading branch information
ericvergnaud and ericvergnaud authored Oct 24, 2024
1 parent 8178d8a commit e8179a1
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 5 deletions.
18 changes: 14 additions & 4 deletions src/databricks/labs/ucx/source_code/linters/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from astroid import ( # type: ignore
Attribute,
Call,
Const,
InferenceError,
Import,
ImportFrom,
Expand Down Expand Up @@ -71,9 +70,20 @@ def _make_sources_for_import_call_nodes(
problems: list[T],
) -> Iterable[ImportSource]:
for node in nodes:
arg = node.args[0]
if isinstance(arg, Const):
yield ImportSource(node, arg.value)
yield from cls._make_sources_for_import_call_node(node, problem_factory, problems)

@classmethod
def _make_sources_for_import_call_node(
cls,
node: Call,
problem_factory: ProblemFactory,
problems: list[T],
) -> Iterable[ImportSource]:
if not node.args:
return
for inferred in InferredValue.infer_from_node(node.args[0]):
if inferred.is_inferred():
yield ImportSource(node, inferred.as_string())
continue
problem = problem_factory(
'dependency-not-constant', "Can't check dependency not provided as a constant", node
Expand Down
42 changes: 41 additions & 1 deletion src/databricks/labs/ucx/source_code/python/python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def first_statement(self) -> NodeNG | None:
return self.tree.first_statement()


class Tree:
class Tree: # pylint: disable=too-many-public-methods

@classmethod
def maybe_parse(cls, code: str) -> MaybeTree:
Expand Down Expand Up @@ -285,6 +285,11 @@ def has_global(self, name: str) -> bool:
self_module: Module = cast(Module, self.node)
return self_module.globals.get(name, None) is not None

def get_global(self, name: str) -> list[NodeNG]:
    """Return the nodes stored under `name` in the module's globals.

    An empty list is returned when `has_global` reports that no such
    global exists.
    """
    if self.has_global(name):
        module_node = cast(Module, self.node)
        return module_node.globals.get(name)
    return []

def nodes_between(self, first_line: int, last_line: int) -> list[NodeNG]:
if not isinstance(self.node, Module):
raise NotImplementedError(f"Can't extract nodes from {type(self.node).__name__}")
Expand Down Expand Up @@ -486,6 +491,7 @@ def __init__(self, node_type: type, match_nodes: list[tuple[str, type]]):
self._matched_nodes: list[NodeNG] = []
self._node_type = node_type
self._match_nodes = match_nodes
self._imports: dict[str, list[NodeNG]] = {}

@property
def matched_nodes(self) -> list[NodeNG]:
Expand Down Expand Up @@ -521,6 +527,7 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
if isinstance(node, Call):
return self._matches(node.func, depth)
name, match_node = self._match_nodes[depth]
node = self._adjust_node_for_import_member(name, match_node, node)
if not isinstance(node, match_node):
return False
next_node: NodeNG | None = None
Expand All @@ -538,6 +545,39 @@ def _matches(self, node: NodeNG, depth: int) -> bool:
return len(self._match_nodes) - 1 == depth
return self._matches(next_node, depth + 1)

def _adjust_node_for_import_member(self, name: str, match_node: type, node: NodeNG) -> NodeNG:
    """Present a bare name imported via `from X import name` as an `X.name` attribute.

    The node is returned untouched when it already is a `match_node`, when the
    type being matched is not `Attribute`, when `node` is not a `Name` called
    `name`, or when the tree's root holds no `ImportFrom` global of that name.
    Otherwise a synthetic `Attribute`, positioned on the import statement, is
    returned.
    """
    if isinstance(node, match_node):
        return node
    # only a bare global name standing in for an attribute is worth resolving
    could_be_import_member = match_node == Attribute and isinstance(node, Name) and node.name == name
    if not could_be_import_member:
        return node
    root_tree = Tree(Tree(node).root)
    if not root_tree.has_global(node.name):
        return node
    # the first ImportFrom bound to that global wins; other node kinds are ignored
    candidates = (item for item in root_tree.get_global(node.name) if isinstance(item, ImportFrom))
    origin = next(candidates, None)
    if origin is None:
        return node
    # both synthetic nodes borrow the location of the import statement
    location = {
        "lineno": origin.lineno,
        "col_offset": origin.col_offset,
        "end_lineno": origin.end_lineno,
        "end_col_offset": origin.end_col_offset,
    }
    module_name = Name(name=origin.modname, parent=origin.parent, **location)
    member = Attribute(attrname=name, parent=module_name, **location)
    member.postinit(module_name)
    return member


class NodeBase(ABC):

Expand Down
19 changes: 19 additions & 0 deletions tests/integration/source_code/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,22 @@ def module_compatibility(self, name: str) -> Compatibility:
# visit the graph without a 'visited' set
roots = graph.root_dependencies
assert roots


def test_graph_imports_dynamic_import():
    """Building the dependency graph for the `import-module.py` sample yields no problems."""
    here = Path(__file__).parent
    known = KnownList()
    libraries = PythonLibraryResolver(known)
    notebooks = NotebookResolver(NotebookLoader())
    files = ImportFileResolver(FileLoader(), known)
    lookup = PathLookup.from_sys_path(here)
    # the file resolver is passed twice, exactly as the resolver signature is used here
    resolver = DependencyResolver(libraries, notebooks, files, files, lookup)
    sample = here.parent.parent / "unit" / "source_code" / "samples" / "import-module.py"
    assert sample.is_file()
    resolved = resolver.resolve_file(lookup, sample)
    assert resolved.dependency
    dependency_graph = DependencyGraph(resolved.dependency, None, resolver, lookup, CurrentSessionState())
    source_container = resolved.dependency.load(lookup)
    found_problems = source_container.build_dependency_graph(dependency_graph)
    assert not found_problems
12 changes: 12 additions & 0 deletions tests/unit/source_code/python/test_python_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,18 @@ def test_is_from_module() -> None:
assert Tree(save_call).is_from_module("spark")


def test_locates_member_import() -> None:
    """A call to a function brought in with `from ... import` is located as a module attribute."""
    code = """
from importlib import import_module
module = import_module("xyz")
"""
    parsed = Tree.maybe_normalized_parse(code)
    assert parsed.tree is not None, parsed.failure
    matches = parsed.tree.locate(Call, [("import_module", Attribute), ("importlib", Name)])
    assert matches


@pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")])
def test_is_instance_of(source, name, class_name) -> None:
maybe_tree = Tree.maybe_normalized_parse(source)
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/source_code/samples/import-module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Sample input: the module name is bound to a variable first and then passed to
# importlib.import_module, instead of being written as a literal argument.
import importlib

module_name = "astroid"
module = importlib.import_module(module_name)

0 comments on commit e8179a1

Please sign in to comment.