Skip to content

Commit

Permalink
Fix test fetcher (#36129)
Browse files Browse the repository at this point in the history
* fix

* fix

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
ydshieh and ydshieh authored Feb 12, 2025
1 parent 1fae54c commit 4a5a7b9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,11 +1927,16 @@ def fetch__all__(file_content):
if "__all__" not in file_content:
return []

start_index = None
lines = file_content.splitlines()
for index, line in enumerate(lines):
if line.startswith("__all__"):
start_index = index

# There is no line starting with `__all__`
if start_index is None:
return []

lines = lines[start_index:]

if not lines[0].startswith("__all__"):
Expand Down
14 changes: 14 additions & 0 deletions utils/tests_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,11 +735,23 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non
if module.endswith("__init__.py"):
# So we get the imports from that init then try to find where our objects come from.
new_imported_modules = extract_imports(module, cache=cache)

# Add imports via `define_import_structure` after the #35167 as we remove explicit import in `__init__.py`
from transformers.utils.import_utils import define_import_structure

new_imported_modules_2 = define_import_structure(PATH_TO_REPO / module)

for mapping in new_imported_modules_2.values():
for _module, _imports in mapping.items():
_module = module.replace("__init__.py", f"{_module}.py")
new_imported_modules.append((_module, list(_imports)))

for new_module, new_imports in new_imported_modules:
if any(i in new_imports for i in imports):
if new_module not in dependencies:
new_modules.append((new_module, [i for i in new_imports if i in imports]))
imports = [i for i in imports if i not in new_imports]

if len(imports) > 0:
# If there are any objects lefts, they may be a submodule
path_to_module = PATH_TO_REPO / module.replace("__init__.py", "")
Expand All @@ -759,6 +771,7 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non
dependencies.append(module)

imported_modules = new_modules

return dependencies


Expand Down Expand Up @@ -880,6 +893,7 @@ def create_reverse_dependency_map() -> Dict[str, List[str]]:
depending on it recursively. This way the tests impacted by a change in file A are the test files in the list
corresponding to key A in this result.
"""

cache = {}
# Start from the example deps init.
example_deps, examples = init_test_examples_dependencies()
Expand Down

0 comments on commit 4a5a7b9

Please sign in to comment.