Skip to content

Commit

Permalink
Added functionality to support parsing imports within if/else stateme…
Browse files Browse the repository at this point in the history
…nts (#23)
  • Loading branch information
Florian Maas authored Sep 4, 2022
1 parent 8c0a2b2 commit d376a2e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
35 changes: 28 additions & 7 deletions deptry/import_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
import logging
from pathlib import Path
from typing import List
from typing import List, Union

from deptry.notebook_import_extractor import NotebookImportExtractor

Expand All @@ -15,13 +15,13 @@ def __init__(self) -> None:
pass

def get_imported_modules_for_list_of_files(self, list_of_files: List[Path]) -> List[str]:
modules_per_file = [self._get_imported_modules_from_file(file) for file in list_of_files]
modules_per_file = [self.get_imported_modules_from_file(file) for file in list_of_files]
all_modules = self._flatten_list(modules_per_file)
unique_modules = sorted(list(set(all_modules)))
logging.debug(f"All imported modules: {unique_modules}\n")
return unique_modules

def _get_imported_modules_from_file(self, path_to_file: Path) -> List[str]:
def get_imported_modules_from_file(self, path_to_file: Path) -> List[str]:
try:
if str(path_to_file).endswith(".ipynb"):
modules = self._get_imported_modules_from_ipynb(path_to_file)
Expand All @@ -33,20 +33,41 @@ def _get_imported_modules_from_file(self, path_to_file: Path) -> List[str]:
raise (e)
return modules

def get_imported_modules_from_str(self, file_str: str) -> List[str]:
root = ast.parse(file_str)
import_nodes = self._get_import_nodes_from(root)
return self._get_import_modules_from(import_nodes)

def _get_imported_modules_from_py(self, path_to_py_file: Path) -> List[str]:
with open(path_to_py_file) as f:
root = ast.parse(f.read(), path_to_py_file) # type: ignore
return self._get_modules_from_ast_root(root)
import_nodes = self._get_import_nodes_from(root)
return self._get_import_modules_from(import_nodes)

def _get_imported_modules_from_ipynb(self, path_to_ipynb_file: Path) -> List[str]:
imports = NotebookImportExtractor().extract(path_to_ipynb_file)
root = ast.parse("\n".join(imports))
return self._get_modules_from_ast_root(root)
import_nodes = self._get_import_nodes_from(root)
return self._get_import_modules_from(import_nodes)

def _get_import_nodes_from(self, root: Union[ast.Module, ast.If]):
"""
Recursively collect import nodes from a Python module. This is needed to find imports that
are defined within if/else statements. In that case, the ast.Import or ast.ImportFrom node
is a child of an ast.If node.
"""
imports = []
for node in ast.iter_child_nodes(root):
if isinstance(node, ast.If):
imports += self._get_import_nodes_from(node)
elif isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
imports += [node]
return imports

@staticmethod
def _get_modules_from_ast_root(root: ast.Module) -> List[str]:
def _get_import_modules_from(nodes: List[Union[ast.Import, ast.ImportFrom]]) -> List[str]:
modules = []
for node in ast.iter_child_nodes(root):
for node in nodes:
if isinstance(node, ast.Import):
modules += [x.name.split(".")[0] for x in node.names]
elif isinstance(node, ast.ImportFrom):
Expand Down
20 changes: 18 additions & 2 deletions tests/test_import_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,28 @@


def test_import_parser_py():
imported_modules = ImportParser()._get_imported_modules_from_file(Path("tests/data/some_imports.py"))
imported_modules = ImportParser().get_imported_modules_from_file(Path("tests/data/some_imports.py"))
assert set(imported_modules) == set(["os", "pathlib", "typing", "pandas", "numpy"])


def test_import_parser_ipynb():
imported_modules = ImportParser()._get_imported_modules_from_file(
imported_modules = ImportParser().get_imported_modules_from_file(
Path("tests/data/projects/project_with_obsolete/src/notebook.ipynb")
)
assert set(imported_modules) == set(["click", "pandas", "numpy", "cookiecutter_poetry"])


def test_import_parser_ifelse():
imported_modules = ImportParser().get_imported_modules_from_str(
"""
x=1
import numpy
if x>0:
import pandas
elif x<0:
from typing import List
else:
import logging
"""
)
assert set(imported_modules) == set(["numpy", "pandas", "typing", "logging"])

0 comments on commit d376a2e

Please sign in to comment.