Skip to content

Commit

Permalink
fix: Fixed failing tests in tests/utils/test_add_new_model_like.py (#…
Browse files Browse the repository at this point in the history
…32678)

* Fixed failing tests in tests/utils/test_add_new_model_like.py

* Fixed formatting using ruff.

* Small nit.
  • Loading branch information
Sai-Suraj-27 authored Aug 14, 2024
1 parent a22ff36 commit df32347
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/utils/test_add_new_model_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"src/transformers/models/vit/convert_vit_timm_to_pytorch.py",
"src/transformers/models/vit/feature_extraction_vit.py",
"src/transformers/models/vit/image_processing_vit.py",
"src/transformers/models/vit/image_processing_vit_fast.py",
"src/transformers/models/vit/modeling_vit.py",
"src/transformers/models/vit/modeling_tf_vit.py",
"src/transformers/models/vit/modeling_flax_vit.py",
Expand Down Expand Up @@ -662,7 +663,13 @@ def test_find_base_model_checkpoint(self):
def test_retrieve_model_classes(self):
gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2").items()}
expected_gpt_classes = {
"pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"},
"pt": {
"GPT2ForTokenClassification",
"GPT2Model",
"GPT2LMHeadModel",
"GPT2ForSequenceClassification",
"GPT2ForQuestionAnswering",
},
"tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"},
"flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"},
}
Expand Down Expand Up @@ -836,7 +843,7 @@ def test_retrieve_info_for_model_with_wav2vec2(self):
]
expected_model_classes = {
"pt": set(wav2vec2_classes),
"tf": {f"TF{m}" for m in wav2vec2_classes[:1]},
"tf": {f"TF{m}" for m in [wav2vec2_classes[0], wav2vec2_classes[-2]]},
"flax": {f"Flax{m}" for m in wav2vec2_classes[:2]},
}

Expand Down Expand Up @@ -870,7 +877,7 @@ def test_retrieve_info_for_model_with_wav2vec2(self):
self.assertEqual(wav2vec2_model_patterns.model_type, "wav2vec2")
self.assertEqual(wav2vec2_model_patterns.model_lower_cased, "wav2vec2")
self.assertEqual(wav2vec2_model_patterns.model_camel_cased, "Wav2Vec2")
self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV_2_VEC_2")
self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV2VEC2")
self.assertEqual(wav2vec2_model_patterns.config_class, "Wav2Vec2Config")
self.assertEqual(wav2vec2_model_patterns.feature_extractor_class, "Wav2Vec2FeatureExtractor")
self.assertEqual(wav2vec2_model_patterns.processor_class, "Wav2Vec2Processor")
Expand Down

0 comments on commit df32347

Please sign in to comment.