
Commit a42d7fe

feat(): More cleanups on mixtral pytest script
1 parent 6a3e9a6 commit a42d7fe

1 file changed: 32 additions, 23 deletions

tests/transformers/tests/models/mixtral/test_modeling_mixtral.py (+32 -23)
@@ -564,10 +564,30 @@ def test_greedy_generate_dict_outputs_use_cache(self):
     def test_retain_grad_hidden_states_attentions(self):
         pass

+    @unittest.skip(reason="This test is not supported for Mixtral")
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        pass
+
+    @unittest.skip(reason="This test is not supported for Mixtral")
+    def test_assisted_decoding_sample(self):
+        pass
+
     @unittest.skip(reason="This test is not supported for Mixtral")
     def test_sample_generate_dict_output(self):
         pass

+    @unittest.skip(reason="Mixtral buffers include complex numbers, which breaks this test")
+    def test_save_load_fast_init_from_base(self):
+        pass
+
+    @unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
+    def test_past_key_values_format(self):
+        pass
+
+    @unittest.skip(reason="NotImplemented reorder_cache` function is not correctly implemented")
+    def test_constrained_beam_search_generate(self):
+        pass
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
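
Note: among the skips added above, test_past_key_values_format is disabled because Mixtral uses grouped-query attention (GQA), so its per-layer KV cache carries fewer key/value heads than attention heads and does not match the shape a generic cache-format check expects. The snippet below is a minimal, self-contained illustration of that shape mismatch; the head counts are made-up example values, not Mixtral's real configuration.

    # Illustrative sketch only -- head counts below are hypothetical, not Mixtral's config.
    import torch

    batch, seq_len, head_dim = 2, 16, 8
    num_attention_heads = 8    # query heads
    num_key_value_heads = 2    # GQA: fewer key/value heads than query heads

    # A generic cache-format check typically expects per-layer key/value tensors
    # shaped (batch, num_attention_heads, seq_len, head_dim).
    expected_mha_shape = (batch, num_attention_heads, seq_len, head_dim)

    # Under GQA the cache only stores num_key_value_heads heads, so the second
    # dimension is smaller and the generic expectation does not hold.
    gqa_key = torch.zeros(batch, num_key_value_heads, seq_len, head_dim)
    gqa_value = torch.zeros(batch, num_key_value_heads, seq_len, head_dim)

    assert gqa_key.shape != expected_mha_shape
    print(gqa_key.shape)  # torch.Size([2, 2, 16, 8])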
@@ -599,17 +619,18 @@ def test_model_various_embeddings(self):
         self.model_tester.create_and_check_model(*config_and_inputs)

     def test_Mixtral_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        print(config)
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = MixtralForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+        with torch.inference_mode():
+            config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            print(config)
+            config.num_labels = 3
+            input_ids = input_dict["input_ids"]
+            attention_mask = input_ids.ne(1).to(torch_device)
+            sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+            model = MixtralForSequenceClassification(config)
+            model.to(torch_device)
+            model.eval()
+            result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+            self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

     def test_Mixtral_sequence_classification_model_for_single_label(self):
         # Starting 1.20, we added torch.inference_mode context manager here.
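
Note: the sequence-classification test above is now wrapped in torch.inference_mode(), matching the comment already present in the neighboring single-label test. The sketch below is a generic illustration of what the context manager does, independent of this test file: autograd recording is turned off inside the block and the tensors it produces are marked as inference tensors.

    # Minimal sketch of torch.inference_mode(), not tied to this repository.
    import torch

    x = torch.randn(4, 4, requires_grad=True)

    with torch.inference_mode():
        y = x * 2                       # no autograd graph is recorded here
        print(y.requires_grad)          # False
        print(torch.is_inference(y))    # True: y is an inference tensor

    # Outside the context, autograd behaves as usual.
    z = x * 2
    print(z.requires_grad)  # True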
@@ -659,18 +680,6 @@ def test_Mixtral_token_classification_model(self):
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Mixtral buffers include complex numbers, which breaks this test")
-    def test_save_load_fast_init_from_base(self):
-        pass
-
-    @unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
-    @unittest.skip(reason="NotImplemented reorder_cache` function is not correctly implemented")
-    def test_constrained_beam_search_generate(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
