From f0bc926dfcc1798bc2e905e270f5496ae0690298 Mon Sep 17 00:00:00 2001 From: Matthijs Hollemans Date: Tue, 30 May 2023 15:06:58 +0200 Subject: [PATCH] fix Whisper tests on GPU (#23753) * move input features to GPU * skip these tests because undefined behavior * unskip tests --- tests/models/whisper/test_modeling_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2be7f6884e72..3eee5ad4967c 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1477,7 +1477,7 @@ def test_generate_with_prompt_ids(self): model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") model.to(torch_device) input_speech = self._load_datasamples(4)[-1:] - input_features = processor(input_speech, return_tensors="pt").input_features + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) output_without_prompt = model.generate(input_features) prompt_ids = processor.get_prompt_ids("Leighton") @@ -1494,7 +1494,7 @@ def test_generate_with_prompt_ids_and_forced_decoder_ids(self): model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") model.to(torch_device) input_speech = self._load_datasamples(1) - input_features = processor(input_speech, return_tensors="pt").input_features + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) task = "translate" language = "de" expected_tokens = [f"<|{task}|>", f"<|{language}|>"] @@ -1513,7 +1513,7 @@ def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self): model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") model.to(torch_device) input_speech = self._load_datasamples(1) - input_features = processor(input_speech, return_tensors="pt").input_features + input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device) prompt = "test prompt" prompt_ids = processor.get_prompt_ids(prompt)