Skip to content

Commit

Permalink
[Modular] skip modular checks based on diff (#36130)
Browse files Browse the repository at this point in the history
skip modular checks based on diff
  • Loading branch information
gante authored Feb 13, 2025
1 parent 6397916 commit d114a6f
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 41 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
- run:
name: "Prepare pipeline parameters"
command: |
python utils/process_test_artifacts.py
python utils/process_test_artifacts.py
# To keep generated_config.yaml from getting too long for the continuation orb, we pass the links to the artifacts as parameters.
# Otherwise the list of tests was just too big. Being explicit is good, but here it was a limitation.
Expand Down Expand Up @@ -110,7 +110,7 @@ jobs:
- run:
name: "Prepare pipeline parameters"
command: |
python utils/process_test_artifacts.py
python utils/process_test_artifacts.py
# To keep generated_config.yaml from getting too long for the continuation orb, we pass the links to the artifacts as parameters.
# Otherwise the list of tests was just too big. Being explicit is good, but here it was a limitation.
Expand Down
2 changes: 1 addition & 1 deletion tests/repo_utils/modular/test_conversion_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool:
class ConversionOrderTest(unittest.TestCase):
def test_conversion_order(self):
# Find the order
priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
priority_list, _ = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
# Extract just the model names
model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list]

Expand Down
34 changes: 0 additions & 34 deletions utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,40 +1024,6 @@ def _rep(match):
return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"


def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
    """
    Extract the block of text that sits between two marker lines of a file.

    Blank lines directly after the start marker and directly before the end marker are trimmed from the result.

    Args:
        filename (`str`): Path of the file to read.
        start_prompt (`str`): Prefix of the line that introduces the wanted content.
        end_prompt (`str`): Prefix of the line that closes the wanted content.

    Returns:
        `Tuple[str, int, int, List[str]]`: The text between the two markers, the index of its first line in the
        file, the (exclusive) index of its last line in the file, and all the lines of the file.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as handle:
        lines = handle.readlines()

    def _is_blank(position: int) -> bool:
        # A line holding at most its newline character counts as blank.
        return len(lines[position]) <= 1

    # Walk down to the start marker; the content begins on the line after it.
    first = 0
    while True:
        if lines[first].startswith(start_prompt):
            break
        first += 1
    first += 1

    # From there, walk down to the end marker; the content stops on the line before it.
    last = first
    while True:
        if lines[last].startswith(end_prompt):
            break
        last += 1
    last -= 1

    # Trim blank lines on both sides of the content.
    while _is_blank(first):
        first += 1
    while _is_blank(last):
        last -= 1

    # Make the end bound exclusive so it can be used directly in a slice.
    stop = last + 1
    content = "".join(lines[first:stop])
    return content, first, stop, lines


# Map a model name with the name it has in the README for the check_readme check
SPECIAL_MODEL_NAMES = {
"Bert Generation": "BERT For Sequence Generation",
Expand Down
76 changes: 75 additions & 1 deletion utils/check_modular_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import difflib
import glob
import logging
import subprocess
from io import StringIO

from create_dependency_mapping import find_priority_list
Expand Down Expand Up @@ -61,6 +62,56 @@ def compare_files(modular_file_path, fix_and_overwrite=False):
return diff


def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    The diff is taken against the merge base of `main` and `HEAD`, as reported by `git merge-base`. Deleted files are
    filtered out (`--diff-filter=d`). A file counts as a model file when its path contains `/models/` and ends with
    `.py`; the model name is the name of the directory that directly contains the file.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).

    Raises:
        subprocess.CalledProcessError: If one of the `git` commands exits with a non-zero status.
    """
    # Pass the commands as argument lists and strip the trailing newline from the SHA explicitly, instead of relying
    # on whitespace-splitting of a formatted command string.
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8").strip()
    diff_output = subprocess.check_output(
        ["git", "diff", "--diff-filter=d", "--name-only", fork_point_sha]
    ).decode("utf-8")

    # `--name-only` prints one path per line: split on lines (not on any whitespace) so that paths containing spaces
    # are kept in one piece.
    modified_files = [line for line in diff_output.splitlines() if line]

    # Matches both modelling files and tests
    return {
        file_path.split("/")[-2]
        for file_path in modified_files
        if "/models/" in file_path and file_path.endswith(".py")
    }


def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Tells whether a modular file can be skipped because its generated modeling file cannot have changed.

    The guarantee holds only when neither the model itself nor any of the models it depends on appears in the diff:
        - the model is in the diff -> no guarantee
        - one of its dependencies is in the diff -> no guarantee
        - otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary containing the dependencies of each modular file.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    # `.../modular_<name>.py` -> `<name>`
    own_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if own_name in models_in_diff:
        return False

    # Dependencies look like `transformers.models.model_name.(...)` or `model_name.(...)`: in both patterns the model
    # name is the second-to-last dotted component.
    return not any(
        dependency.split(".")[-2] in models_in_diff for dependency in dependencies[modular_file_path]
    )


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
parser.add_argument(
Expand All @@ -72,9 +123,32 @@ def compare_files(modular_file_path, fix_and_overwrite=False):
args = parser.parse_args()
if args.files == ["all"]:
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
# are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
# script will do nothing.
models_in_diff = get_models_in_diff()
if not models_in_diff:
console.print("[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]")
exit(0)

skipped_models = set()
non_matching_files = 0
for modular_file_path in find_priority_list(args.files):
ordered_files, dependencies = find_priority_list(args.files)
for modular_file_path in ordered_files:
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
if is_guaranteed_no_diff:
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
skipped_models.add(model_name)
continue
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
models_in_diff = get_models_in_diff() # When overwriting, the diff changes

if non_matching_files and not args.fix_and_overwrite:
raise ValueError("Some diff and their modeling code did not match.")

if skipped_models:
console.print(
f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
f"{', '.join(skipped_models)}[/bold green]"
)
14 changes: 12 additions & 2 deletions utils/create_dependency_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def map_dependencies(py_files):


def find_priority_list(py_files):
    """
    Given a list of modular files, sorts them by topological order. Modular models that DON'T depend on other modular
    models will be higher in the topological order.

    Args:
        py_files: List of paths to the modular files

    Returns:
        A tuple with the ordered files (list) and their dependencies (dict)
    """
    # The dependency mapping is returned alongside the ordering so callers can reuse it without recomputing it.
    dependencies = map_dependencies(py_files)
    ordered_files = topological_sort(dependencies)
    return ordered_files, dependencies
2 changes: 1 addition & 1 deletion utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,7 +1716,7 @@ def save_modeling_file(modular_file, converted_file):
if args.files_to_parse == ["examples"]:
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)

priority_list = find_priority_list(args.files_to_parse)
priority_list, _ = find_priority_list(args.files_to_parse)
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"

for file_name in priority_list:
Expand Down

0 comments on commit d114a6f

Please sign in to comment.