Skip to content

Commit

Permalink
fix low-precision audio classification pipeline (huggingface#35435)
Browse files Browse the repository at this point in the history
* fix low-precision audio classification pipeline

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add test

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torch import

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix torch import

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
  • Loading branch information
jiqing-feng authored and elvircrn committed Feb 13, 2025
1 parent fb08877 commit 83d2a94
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/transformers/pipelines/audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def preprocess(self, inputs):
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
return processed

def _forward(self, model_inputs):
Expand Down
37 changes: 36 additions & 1 deletion tests/pipelines/test_pipelines_audio_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import numpy as np
from huggingface_hub import AudioClassificationOutputElement

from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
from transformers import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
is_torch_available,
)
from transformers.pipelines import AudioClassificationPipeline, pipeline
from transformers.testing_utils import (
compare_pipeline_output_to_hub_spec,
Expand All @@ -32,6 +36,10 @@
from .test_pipelines_common import ANY


if is_torch_available():
import torch


@is_pipeline_test
class AudioClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
Expand Down Expand Up @@ -127,6 +135,33 @@ def test_small_model_pt(self):
output = audio_classifier(audio_dict, top_k=4)
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])

@require_torch
def test_small_model_pt_fp16(self):
model = "anton-l/wav2vec2-random-tiny-classifier"

audio_classifier = pipeline("audio-classification", model=model, torch_dtype=torch.float16)

audio = np.ones((8000,))
output = audio_classifier(audio, top_k=4)

EXPECTED_OUTPUT = [
{"score": 0.0839, "label": "no"},
{"score": 0.0837, "label": "go"},
{"score": 0.0836, "label": "yes"},
{"score": 0.0835, "label": "right"},
]
EXPECTED_OUTPUT_PT_2 = [
{"score": 0.0845, "label": "stop"},
{"score": 0.0844, "label": "on"},
{"score": 0.0841, "label": "right"},
{"score": 0.0834, "label": "left"},
]
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])

audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
output = audio_classifier(audio_dict, top_k=4)
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])

@require_torch
@slow
def test_large_model_pt(self):
Expand Down

0 comments on commit 83d2a94

Please sign in to comment.