diff --git a/.circleci/config.yml b/.circleci/config.yml index ecd7066931a9..dbbebe9fc065 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -58,7 +58,7 @@ jobs: - run: name: "Prepare pipeline parameters" command: | - python utils/process_test_artifacts.py + python utils/process_test_artifacts.py # To avoid too long generated_config.yaml on the continuation orb, we pass the links to the artifacts as parameters. # Otherwise the list of tests was just too big. Explicit is good but for that it was a limitation. @@ -110,7 +110,7 @@ jobs: - run: name: "Prepare pipeline parameters" command: | - python utils/process_test_artifacts.py + python utils/process_test_artifacts.py # To avoid too long generated_config.yaml on the continuation orb, we pass the links to the artifacts as parameters. # Otherwise the list of tests was just too big. Explicit is good but for that it was a limitation. diff --git a/tests/repo_utils/modular/test_conversion_order.py b/tests/repo_utils/modular/test_conversion_order.py index f5e133ce1fea..643b3714a763 100644 --- a/tests/repo_utils/modular/test_conversion_order.py +++ b/tests/repo_utils/modular/test_conversion_order.py @@ -48,7 +48,7 @@ def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool: class ConversionOrderTest(unittest.TestCase): def test_conversion_order(self): # Find the order - priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE) + priority_list, _ = create_dependency_mapping.find_priority_list(FILES_TO_PARSE) # Extract just the model names model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list] diff --git a/utils/check_copies.py b/utils/check_copies.py index daf90b4d4a2c..c62a192c1075 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -1024,40 +1024,6 @@ def _rep(match): return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n" -def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> 
Tuple[str, int, int, List[str]]: - """ - Find the text in a file between two prompts. - - Args: - filename (`str`): The name of the file to look into. - start_prompt (`str`): The string to look for that introduces the content looked for. - end_prompt (`str`): The string to look for that ends the content looked for. - - Returns: - Tuple[str, int, int, List[str]]: The content between the two prompts, the index of the start line in the - original file, the index of the end line in the original file and the list of lines of that file. - """ - with open(filename, "r", encoding="utf-8", newline="\n") as f: - lines = f.readlines() - # Find the start prompt. - start_index = 0 - while not lines[start_index].startswith(start_prompt): - start_index += 1 - start_index += 1 - - end_index = start_index - while not lines[end_index].startswith(end_prompt): - end_index += 1 - end_index -= 1 - - while len(lines[start_index]) <= 1: - start_index += 1 - while len(lines[end_index]) <= 1: - end_index -= 1 - end_index += 1 - return "".join(lines[start_index:end_index]), start_index, end_index, lines - - # Map a model name with the name it has in the README for the check_readme check SPECIAL_MODEL_NAMES = { "Bert Generation": "BERT For Sequence Generation", diff --git a/utils/check_modular_conversion.py b/utils/check_modular_conversion.py index 5946d6ef1687..e08621b5c32c 100644 --- a/utils/check_modular_conversion.py +++ b/utils/check_modular_conversion.py @@ -2,6 +2,7 @@ import difflib import glob import logging +import subprocess from io import StringIO from create_dependency_mapping import find_priority_list @@ -61,6 +62,56 @@ def compare_files(modular_file_path, fix_and_overwrite=False): return diff +def get_models_in_diff(): + """ + Finds all models that have been modified in the diff. + + Returns: + A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}). 
+ """ + fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8") + modified_files = ( + subprocess.check_output(f"git diff --diff-filter=d --name-only {fork_point_sha}".split()) + .decode("utf-8") + .split() + ) + + # Keeps every modified `.py` file located under a `/models/` directory (modeling files as well as model tests) + relevant_modified_files = [x for x in modified_files if "/models/" in x and x.endswith(".py")] + model_names = set() + for file_path in relevant_modified_files: + model_name = file_path.split("/")[-2] + model_names.add(model_name) + return model_names + + +def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff): + """ + Returns whether it is guaranteed to have no differences between the modular file and the modeling file. + + Model is in the diff -> not guaranteed to have no differences + Dependency is in the diff -> not guaranteed to have no differences + Otherwise -> guaranteed to have no differences + + Args: + modular_file_path: The path to the modular file. + dependencies: A dictionary containing the dependencies of each modular file. + models_in_diff: A set containing the names of the models that have been modified. + + Returns: + A boolean indicating whether the model (code and tests) is guaranteed to have no differences. 
+ """ + model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "") + if model_name in models_in_diff: + return False + for dep in dependencies[modular_file_path]: + # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)` + dependency_model_name = dep.split(".")[-2] + if dependency_model_name in models_in_diff: + return False + return True + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.") parser.add_argument( @@ -72,9 +123,32 @@ def compare_files(modular_file_path, fix_and_overwrite=False): args = parser.parse_args() if args.files == ["all"]: args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) + + # Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies + # are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this + # script will do nothing. 
+ models_in_diff = get_models_in_diff() + if not models_in_diff: + console.print("[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]") + exit(0) + + skipped_models = set() non_matching_files = 0 - for modular_file_path in find_priority_list(args.files): + ordered_files, dependencies = find_priority_list(args.files) + for modular_file_path in ordered_files: + is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff) + if is_guaranteed_no_diff: + model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "") + skipped_models.add(model_name) + continue non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite) + models_in_diff = get_models_in_diff() # When overwriting, the diff changes if non_matching_files and not args.fix_and_overwrite: raise ValueError("Some diff and their modeling code did not match.") + + if skipped_models: + console.print( + f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: " + f"{', '.join(skipped_models)}[/bold green]" + ) diff --git a/utils/create_dependency_mapping.py b/utils/create_dependency_mapping.py index 5cf38cdd1f81..f0f62cf1b000 100644 --- a/utils/create_dependency_mapping.py +++ b/utils/create_dependency_mapping.py @@ -55,6 +55,16 @@ def map_dependencies(py_files): def find_priority_list(py_files): + """ + Given a list of modular files, sorts them by topological order. Modular models that DON'T depend on other modular + models will be higher in the topological order. 
+ + Args: + py_files: List of paths to the modular files + + Returns: + A tuple with the ordered files (list) and their dependencies (dict) + """ dependencies = map_dependencies(py_files) - ordered_classes = topological_sort(dependencies) - return ordered_classes + ordered_files = topological_sort(dependencies) + return ordered_files, dependencies diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index bb7799b4682a..3c5d062fbe3a 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -1716,7 +1716,7 @@ def save_modeling_file(modular_file, converted_file): if args.files_to_parse == ["examples"]: args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True) - priority_list = find_priority_list(args.files_to_parse) + priority_list, _ = find_priority_list(args.files_to_parse) assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted" for file_name in priority_list: