From da79b18087433506ad32ad62e8c3a68892410afe Mon Sep 17 00:00:00 2001
From: Sai-Suraj-27
Date: Wed, 10 Jul 2024 18:16:31 +0530
Subject: [PATCH] fix: Removed `duplicate` field definitions in some classes (#31888)

Removed duplicate field definitions in classes.
---
 examples/flax/language-modeling/run_clm_flax.py               | 3 ---
 .../jax-projects/hybrid_clip/run_hybrid_clip.py               | 3 ---
 .../jax-projects/model_parallel/run_clm_mp.py                 | 3 ---
 .../models/deformable_detr/modeling_deformable_detr.py        | 1 -
 src/transformers/models/video_llava/modeling_video_llava.py   | 1 -
 tests/models/fnet/test_modeling_fnet.py                       | 1 -
 tests/models/mamba/test_modeling_mamba.py                     | 1 -
 tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py | 1 -
 8 files changed, 14 deletions(-)

diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py
index a12bc6d3c8d5..c486aae71f62 100755
--- a/examples/flax/language-modeling/run_clm_flax.py
+++ b/examples/flax/language-modeling/run_clm_flax.py
@@ -225,9 +225,6 @@ class DataTrainingArguments:
             )
         },
     )
-    overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
-    )
     validation_split_percentage: Optional[int] = field(
         default=5,
         metadata={
diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
index f954f70ee48b..2020f0a35c40 100644
--- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
+++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
@@ -163,9 +163,6 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
-    overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
-    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
index a72e5cff861c..067f7cb2b185 100644
--- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
+++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py
@@ -156,9 +156,6 @@ class DataTrainingArguments:
             )
         },
     )
-    overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
-    )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
index cfa08e3974b7..03648d33b9ad 100755
--- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1080,7 +1080,6 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
-    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index b540576d4866..9be43151f3c2 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -126,7 +126,6 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VideoLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
-    _no_split_modules = ["VideoLlavaVisionAttention"]
 
     def _init_weights(self, module):
         std = (
diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py
index b7acf3610c08..826bf4857110 100644
--- a/tests/models/fnet/test_modeling_fnet.py
+++ b/tests/models/fnet/test_modeling_fnet.py
@@ -295,7 +295,6 @@ class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # Skip Tests
     test_pruning = False
     test_head_masking = False
-    test_pruning = False
 
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 1ddb8ad700b9..4220fabd40b6 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -258,7 +258,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     test_model_parallel = False
     test_pruning = False
     test_head_masking = False  # Mamba does not have attention heads
-    test_model_parallel = False
     pipeline_model_mapping = (
         {"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {}
     )
diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
index 5a391ee60417..ad542db2733b 100644
--- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
+++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
@@ -298,7 +298,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     test_model_parallel = False
     test_pruning = False
     test_head_masking = False  # RecurrentGemma does not have attention heads
-    test_model_parallel = False
 
     # Need to remove 0.9 in `test_cpu_offload`
     # This is because we are hitting edge cases with the causal_mask buffer
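
Note on why these removals are behavior-preserving (this note is not part of the patch): a Python class body executes top to bottom, so a repeated field or attribute definition raises no error; the later binding silently replaces the earlier one. A minimal runnable sketch, reusing the DataTrainingArguments and overwrite_cache names from the patch but otherwise hypothetical:

from dataclasses import dataclass, field, fields


@dataclass
class DataTrainingArguments:
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    # Duplicate definition: no error is raised; it silently replaces the one above.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


# The generated dataclass still has exactly one overwrite_cache field, which is
# why deleting an identical duplicate cannot change behavior.
print([f.name for f in fields(DataTrainingArguments)])  # ['overwrite_cache']

The same reasoning covers the plain class attributes removed above (supports_gradient_checkpointing, _no_split_modules, test_pruning, test_model_parallel): only the last assignment to a name in the class body takes effect, so an identical earlier duplicate is dead code.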