From 0c72a56e9adbd863e8ce8ac2198c4edeb0f107e3 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 24 Feb 2023 22:32:03 +0530 Subject: [PATCH] Add ORTModelXXX for audio (#774) * add ctc and sequence classification model * add xvector * add audio frame classification * update docstring documentation * updated docs * add audio classification test * update models * update models * add audio classification test * add audio classification test * add audio classification test * fixed tests * fix tests * updatte model * rebase * update test * fix test * fixed tests * update docs * fix code docstring * update docs * add numpy tests * add error for iobinding * update docs * update docs --------- Co-authored-by: Mohit Sharma --- .../package_reference/modeling_ort.mdx | 83 +++- optimum/exporters/onnx/model_configs.py | 30 +- optimum/onnxruntime/__init__.py | 8 + optimum/onnxruntime/base.py | 132 ++++-- optimum/onnxruntime/modeling_ort.py | 390 ++++++++++++++++ optimum/onnxruntime/modeling_seq2seq.py | 70 ++- optimum/pipelines.py | 12 +- optimum/utils/normalized_config.py | 8 + setup.py | 11 +- tests/onnxruntime/test_modeling.py | 435 +++++++++++++++++- tests/onnxruntime/utils_onnxruntime_tests.py | 4 +- 11 files changed, 1087 insertions(+), 96 deletions(-) diff --git a/docs/source/onnxruntime/package_reference/modeling_ort.mdx b/docs/source/onnxruntime/package_reference/modeling_ort.mdx index 5e0136bb05..54927111f2 100644 --- a/docs/source/onnxruntime/package_reference/modeling_ort.mdx +++ b/docs/source/onnxruntime/package_reference/modeling_ort.mdx @@ -12,57 +12,98 @@ specific language governing permissions and limitations under the License. # Models -## ORTModel +## Generic model classes + +The following ORT classes are available for instantiating a base model class without a specific head. + +### ORTModel [[autodoc]] onnxruntime.ORTModel -## ORTModelForCausalLM +## Natural Language Processing + +The following ORT classes are available for the following natural language processing tasks. + +### ORTModelForCausalLM [[autodoc]] onnxruntime.ORTModelForCausalLM -## ORTModelForCustomTasks +### ORTModelForMaskedLM -[[autodoc]] onnxruntime.ORTModelForCustomTasks +[[autodoc]] onnxruntime.ORTModelForMaskedLM -## ORTModelForFeatureExtraction +### ORTModelForSeq2SeqLM -[[autodoc]] onnxruntime.ORTModelForFeatureExtraction +[[autodoc]] onnxruntime.ORTModelForSeq2SeqLM -## ORTModelForImageClassification +### ORTModelForSequenceClassification -[[autodoc]] onnxruntime.ORTModelForImageClassification +[[autodoc]] onnxruntime.ORTModelForSequenceClassification -## ORTModelForMaskedLM -[[autodoc]] onnxruntime.ORTModelForMaskedLM +### ORTModelForTokenClassification -## ORTModelForMultipleChoice +[[autodoc]] onnxruntime.ORTModelForTokenClassification + +### ORTModelForMultipleChoice [[autodoc]] onnxruntime.ORTModelForMultipleChoice -## ORTModelForQuestionAnswering +## Computer vision + +The following ORT classes are available for the following computer vision tasks. + +### ORTModelForQuestionAnswering [[autodoc]] onnxruntime.ORTModelForQuestionAnswering -## ORTModelForSemanticSegmentation +### ORTModelForImageClassification + +[[autodoc]] onnxruntime.ORTModelForImageClassification + +### ORTModelForSemanticSegmentation [[autodoc]] onnxruntime.ORTModelForSemanticSegmentation -## ORTModelForSeq2SeqLM +## Audio -[[autodoc]] onnxruntime.ORTModelForSeq2SeqLM +The following ORT classes are available for the following audio tasks. -## ORTModelForSequenceClassification +### ORTModelForAudioClassification -[[autodoc]] onnxruntime.ORTModelForSequenceClassification +[[autodoc]] onnxruntime.ORTModelForAudioClassification + +### ORTModelForAudioFrameClassification + +[[autodoc]] onnxruntime.ORTModelForAudioFrameClassification + +### ORTModelForCTC -## ORTModelForSpeechSeq2Seq +[[autodoc]] onnxruntime.ORTModelForCTC + +### ORTModelForSpeechSeq2Seq [[autodoc]] onnxruntime.ORTModelForSpeechSeq2Seq -## ORTModelForTokenClassification +### ORTModelForAudioXVector -[[autodoc]] onnxruntime.ORTModelForTokenClassification +[[autodoc]] onnxruntime.ORTModelForAudioXVector + +## Multimodal + +The following ORT classes are available for the following multimodal tasks. -## ORTModelForVision2Seq +### ORTModelForVision2Seq [[autodoc]] onnxruntime.ORTModelForVision2Seq + +## Custom Tasks + +The following ORT classes are available for the following custom tasks. + +#### ORTModelForCustomTasks + +[[autodoc]] onnxruntime.ORTModelForCustomTasks + +#### ORTModelForFeatureExtraction + +[[autodoc]] onnxruntime.ORTModelForFeatureExtraction diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 590ea9526b..78b50ad9d7 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -958,8 +958,10 @@ class Speech2TextOnnxConfig(AudioToTextOnnxConfig): allow_new=True, ) DUMMY_INPUT_GENERATOR_CLASSES = ( - Speech2TextDummyAudioInputGenerator, - ) + AudioToTextOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:] + (Speech2TextDummyAudioInputGenerator,) + + AudioToTextOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:] + + (DummyTextInputGenerator,) + ) ATOL_FOR_VALIDATION = 1e-4 def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]: @@ -978,11 +980,27 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen @property def inputs(self) -> Dict[str, Dict[int, str]]: - common_inputs = super().inputs + common_inputs = {} + + if self._behavior is not ConfigBehavior.DECODER: + common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"} + common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"} + + if self._behavior is not ConfigBehavior.ENCODER: + if self.use_past_in_inputs: + common_inputs["decoder_input_ids"] = {0: "batch_size"} + else: + common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} + + if self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + if self._behavior is ConfigBehavior.DECODER: - common_inputs["encoder_outputs"][ - 1 - ] = f"{common_inputs['encoder_outputs'][1]} / {( 2 * self._config.num_conv_layers)}" + common_inputs["encoder_outputs"] = { + 0: "batch_size", + 1: f"encoder_sequence_length / {( 2 * self._config.num_conv_layers)}", + } + return common_inputs @property diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 3e51a4e4de..ccc87f2892 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -29,7 +29,11 @@ ], "modeling_ort": [ "ORTModel", + "ORTModelForAudioClassification", + "ORTModelForAudioFrameClassification", + "ORTModelForAudioXVector", "ORTModelForCustomTasks", + "ORTModelForCTC", "ORTModelForFeatureExtraction", "ORTModelForImageClassification", "ORTModelForMaskedLM", @@ -64,6 +68,10 @@ from .modeling_decoder import ORTModelForCausalLM from .modeling_ort import ( ORTModel, + ORTModelForAudioClassification, + ORTModelForAudioFrameClassification, + ORTModelForAudioXVector, + ORTModelForCTC, ORTModelForCustomTasks, ORTModelForFeatureExtraction, ORTModelForImageClassification, diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 42585798b2..b7c6c5bee6 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -74,6 +74,9 @@ def forward( attention_mask: torch.LongTensor, **kwargs, ) -> BaseModelOutput: + use_torch = isinstance(input_ids, torch.Tensor) + self.parent_model.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.parent_model.use_io_binding: model_inputs = [input_ids] if "attention_mask" in self.input_names: @@ -90,15 +93,25 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} + if use_torch: + onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() + # Add the attention_mask inputs when needed + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() + else: + onnx_inputs = {"input_ids": input_ids} + + # Add the attention_mask inputs when needed + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask # Run inference outputs = self.session.run(None, onnx_inputs) - last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self.device) + + last_hidden_state = outputs[self.output_names["last_hidden_state"]] + if use_torch: + last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -254,6 +267,8 @@ def forward( use_cache_branch: None = None, ) -> CausalLMOutputWithCrossAttentions: # adding use_cache_branch in the signature here is just a hack for IO Binding + use_torch = isinstance(input_ids, torch.Tensor) + self.parent_model.raise_on_numpy_input_io_binding(use_torch) # Flatten the past_key_values if past_key_values is not None: @@ -310,21 +325,38 @@ def forward( if "loss" in self.output_names: loss = output_buffers["loss"].view(output_shapes["loss"]) else: - onnx_inputs = { - "input_ids": input_ids.cpu().detach().numpy(), - "attention_mask": attention_mask.cpu().detach().numpy(), - } + if use_torch: + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": attention_mask.cpu().detach().numpy(), + } + + if self.parent_model.use_merged is True: + onnx_inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() + + if past_key_values is not None: + # Add the past_key_values to the decoder inputs + for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): + onnx_inputs[input_name] = past_key_value.cpu().detach().numpy() + + if "labels" in self.input_names: + onnx_inputs["labels"] = labels.cpu().detach().numpy() + else: + onnx_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } - if self.parent_model.use_merged is True: - onnx_inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() + if self.parent_model.use_merged is True: + onnx_inputs["use_cache_branch"] = use_cache_branch - if past_key_values is not None: - # Add the past_key_values to the decoder inputs - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - onnx_inputs[input_name] = past_key_value.cpu().detach().numpy() + if past_key_values is not None: + # Add the past_key_values to the decoder inputs + for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): + onnx_inputs[input_name] = past_key_value - if "labels" in self.input_names: - onnx_inputs["labels"] = labels.cpu().detach().numpy() + if "labels" in self.input_names: + onnx_inputs["labels"] = labels # Run inference outputs = self.session.run(None, onnx_inputs) @@ -382,6 +414,8 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, ) -> Seq2SeqLMOutput: + use_torch = isinstance(input_ids, torch.Tensor) + self.parent_model.raise_on_numpy_input_io_binding(use_torch) # Flatten the past_key_values if past_key_values is not None: past_key_values = tuple( @@ -449,26 +483,49 @@ def filter_out_output(output_name): if "loss" in self.output_names: loss = output_buffers["loss"].view(output_shapes["loss"]) else: - onnx_inputs = { - "input_ids": input_ids.cpu().detach().numpy(), - } + # from pdb import set_trace; set_trace() + if use_torch: + onnx_inputs = { + "input_ids": input_ids.cpu().detach().numpy(), + } + + # Add the encoder_attention_mask inputs when needed + if "encoder_attention_mask" in self.input_names: + onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy() + + # Add the encoder_hidden_states inputs when needed + if "encoder_hidden_states" in self.input_names: + onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy() + + if past_key_values is not None: + # Add the past_key_values to the decoder inputs + for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): + onnx_inputs[input_name] = past_key_value.cpu().detach().numpy() + + if "labels" in self.input_names: + # TODO: Any preprocessing like `self._shift_right(labels)`? + onnx_inputs["labels"] = labels.cpu().detach().numpy() + else: + onnx_inputs = { + "input_ids": input_ids, + } - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy() + # Add the encoder_attention_mask inputs when needed + if "encoder_attention_mask" in self.input_names: + onnx_inputs["encoder_attention_mask"] = encoder_attention_mask - # Add the encoder_hidden_states inputs when needed - if "encoder_hidden_states" in self.input_names: - onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy() + # Add the encoder_hidden_states inputs when needed + if "encoder_hidden_states" in self.input_names: + onnx_inputs["encoder_hidden_states"] = encoder_hidden_states - if past_key_values is not None: - # Add the past_key_values to the decoder inputs - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - onnx_inputs[input_name] = past_key_value.cpu().detach().numpy() + if past_key_values is not None: + # Add the past_key_values to the decoder inputs + for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): + onnx_inputs[input_name] = past_key_value - if "labels" in self.input_names: - # TODO: Any preprocessing like `self._shift_right(labels)`? - onnx_inputs["labels"] = labels.cpu().detach().numpy() + if "labels" in self.input_names: + # TODO: Any preprocessing like `self._shift_right(labels)`? + onnx_inputs["labels"] = labels # Run inference outputs = self.session.run(None, onnx_inputs) @@ -484,10 +541,15 @@ def filter_out_output(output_name): # cross-attention per decoder layer num_pkv = 4 past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)) - logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device) + + logits = outputs[self.output_names["logits"]] + if use_torch: + logits = torch.from_numpy(logits).to(self.device) loss = None if "loss" in self.output_names: - loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device) + loss = outputs[self.output_names["loss"]] + if use_torch: + loss = torch.from_numpy(loss).to(self.device) return Seq2SeqLMOutput(loss=loss, logits=logits, past_key_values=past_key_values) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 5307182441..be3c49d5df 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -27,6 +27,10 @@ from transformers import ( AutoConfig, AutoModel, + AutoModelForAudioClassification, + AutoModelForAudioFrameClassification, + AutoModelForAudioXVector, + AutoModelForCTC, AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, @@ -38,6 +42,7 @@ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.modeling_outputs import ( BaseModelOutput, + CausalLMOutput, ImageClassifierOutput, MaskedLMOutput, ModelOutput, @@ -46,6 +51,7 @@ SemanticSegmenterOutput, SequenceClassifierOutput, TokenClassifierOutput, + XVectorOutput, ) import onnxruntime as ort @@ -77,6 +83,7 @@ _TOKENIZER_FOR_DOC = "AutoTokenizer" _FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" +_PROCESSOR_FOR_DOC = "AutoProcessor" ONNX_MODEL_START_DOCSTRING = r""" This model inherits from [`~onnxruntime.modeling_ort.ORTModel`]. Check the superclass documentation for the generic methods the @@ -117,6 +124,13 @@ Pixel values can be obtained from encoded images using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor). """ +ONNX_AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.Tensor` of shape `({0})`): + Float values of input raw speech waveform.. + Input values can be obtained from audio file loaded into an array using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor). +""" + class classproperty: def __init__(self, getter): @@ -1651,6 +1665,382 @@ def _prepare_onnx_inputs(self, use_torch: bool, **kwargs): return onnx_inputs +AUDIO_CLASSIFICATION_EXAMPLE = r""" + Example of audio classification: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_class_ids = torch.argmax(logits, dim=-1).item() + >>> predicted_label = model.config.id2label[predicted_class_ids] + ``` + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_ac = pipeline("audio-classification", model=model, feature_extractor=feature_extractor) + + >>> pred = onnx_ac(dataset[0]["audio"]["array"]) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForAudioClassification(ORTModel): + """ + Audio Classification model for ONNX. + """ + + auto_model_class = AutoModelForAudioClassification + + @add_start_docstrings_to_model_forward( + ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + AUDIO_CLASSIFICATION_EXAMPLE.format( + processor_class=_FEATURE_EXTRACTOR_FOR_DOC, + model_class="ORTModelForAudioClassification", + checkpoint="optimum/hubert-base-superb-ks", + ) + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + attenton_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + use_torch = isinstance(input_values, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: + io_binding, output_shapes, output_buffers = self.prepare_io_binding( + input_values, ordered_input_names=self._ordered_input_names + ) + + # run inference with binding & synchronize in case of multiple CUDA streams + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + else: + if use_torch: + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_values": input_values.cpu().detach().numpy(), + } + else: + onnx_inputs = { + "input_values": input_values, + } + + # run inference + outputs = self.model.run(None, onnx_inputs) + + logits = outputs[self.output_names["logits"]] + if use_torch: + logits = torch.from_numpy(logits).to(self.device) + + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput(logits=logits) + + +CTC_EXAMPLE = r""" + Example of CTC: + + ```python + >>> from transformers import {processor_class}, HubertForCTC + >>> from optimum.onnxruntime import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + >>> predicted_ids = torch.argmax(logits, dim=-1) + + >>> transcription = processor.batch_decode(predicted_ids) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with a language modeling head on top for Connectionist Temporal Classification (CTC). + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForCTC(ORTModel): + """ + CTC model for ONNX. + """ + + auto_model_class = AutoModelForCTC + + @add_start_docstrings_to_model_forward( + ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + CTC_EXAMPLE.format( + processor_class=_PROCESSOR_FOR_DOC, + model_class="ORTModelForCTC", + checkpoint="optimum/hubert-large-ls960-ft", + ) + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + **kwargs, + ): + use_torch = isinstance(input_values, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: + io_binding, output_shapes, output_buffers = self.prepare_io_binding( + input_values, ordered_input_names=self._ordered_input_names + ) + + # run inference with binding & synchronize in case of multiple CUDA streams + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # converts output to namedtuple for pipelines post-processing + return CausalLMOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + else: + if use_torch: + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_values": input_values.cpu().detach().numpy(), + } + else: + onnx_inputs = { + "input_values": input_values, + } + + # run inference + outputs = self.model.run(None, onnx_inputs) + + logits = outputs[self.output_names["logits"]] + if use_torch: + logits = torch.from_numpy(logits).to(self.device) + # converts output to namedtuple for pipelines post-processing + return CausalLMOutput(logits=logits) + + +AUDIO_XVECTOR_EXAMPLE = r""" + Example of Audio XVector: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor( + ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True + ... ) + >>> with torch.no_grad(): + ... embeddings = model(**inputs).embeddings + + >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu() + + >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1) + >>> similarity = cosine_sim(embeddings[0], embeddings[1]) + >>> threshold = 0.7 + >>> if similarity < threshold: + ... print("Speakers are not the same!") + >>> round(similarity.item(), 2) + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForAudioXVector(ORTModel): + """ + Audio XVector model for ONNX. + """ + + auto_model_class = AutoModelForAudioXVector + + @add_start_docstrings_to_model_forward( + ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + AUDIO_XVECTOR_EXAMPLE.format( + processor_class=_FEATURE_EXTRACTOR_FOR_DOC, + model_class="ORTModelForAudioXVector", + checkpoint="optimum/wav2vec2-base-superb-sv", + ) + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + **kwargs, + ): + use_torch = isinstance(input_values, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: + io_binding, output_shapes, output_buffers = self.prepare_io_binding( + input_values, ordered_input_names=self._ordered_input_names + ) + + # run inference with binding & synchronize in case of multiple CUDA streams + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # converts output to namedtuple for pipelines post-processing + return XVectorOutput( + logits=output_buffers["logits"].view(output_shapes["logits"]), + embeddings=output_buffers["embeddings"].view(output_shapes["embeddings"]), + ) + else: + if use_torch: + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_values": input_values.cpu().detach().numpy(), + } + else: + onnx_inputs = { + "input_values": input_values, + } + + # run inference + outputs = self.model.run(None, onnx_inputs) + + logits = outputs[self.output_names["logits"]] + embeddings = outputs[self.output_names["embeddings"]] + if use_torch: + logits = torch.from_numpy(logits).to(self.device) + embeddings = torch.from_numpy(embeddings).to(self.device) + + # converts output to namedtuple for pipelines post-processing + return XVectorOutput(logits=logits, embeddings=embeddings) + + +AUDIO_FRAME_CLASSIFICATION_EXAMPLE = r""" + Example of audio frame classification: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate) + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> probabilities = torch.sigmoid(logits[0]) + >>> labels = (probabilities > 0.5).long() + >>> labels[0].tolist() + ``` +""" + + +@add_start_docstrings( + """ + Onnx Model for with a frame classification head on top for tasks like Speaker Diarization. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForAudioFrameClassification(ORTModel): + """ + Audio Frame Classification model for ONNX. + """ + + auto_model_class = AutoModelForAudioFrameClassification + + @add_start_docstrings_to_model_forward( + ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + AUDIO_FRAME_CLASSIFICATION_EXAMPLE.format( + processor_class=_FEATURE_EXTRACTOR_FOR_DOC, + model_class="ORTModelForAudioFrameClassification", + checkpoint="optimum/wav2vec2-base-superb-sd", + ) + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + **kwargs, + ): + use_torch = isinstance(input_values, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + + if self.device.type == "cuda" and self.use_io_binding: + raise NotImplementedError() + else: + if use_torch: + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "input_values": input_values.cpu().detach().numpy(), + } + else: + onnx_inputs = { + "input_values": input_values, + } + + # run inference + outputs = self.model.run(None, onnx_inputs) + + logits = outputs[self.output_names["logits"]] + if use_torch: + logits = torch.from_numpy(logits).to(self.device) + # converts output to namedtuple for pipelines post-processing + return TokenClassifierOutput(logits=logits) + + CUSTOM_TASKS_EXAMPLE = r""" Example of custom tasks(e.g. a sentence transformers taking `pooler_output` as output): diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 5645eaddf5..c23acfaa0f 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -23,6 +23,7 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch from huggingface_hub import hf_hub_download from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig @@ -74,10 +75,10 @@ `(batch_size, encoder_sequence_length)`. Mask values selected in `[0, 1]`. """ -WHISPER_ENCODER_INPUTS_DOCSTRING = r""" +SPEECH_ENCODER_INPUTS_DOCSTRING = r""" Args: input_features (`torch.FloatTensor`): - Mel features extracted from the raw speech waveform. `(batch_size, feature_size, encoder_sequence_length)`. + Mel / fbank features extracted from the raw speech waveform. `(batch_size, feature_size, encoder_sequence_length)`. """ VISION_ENCODER_INPUTS_DOCSTRING = r""" @@ -273,7 +274,7 @@ DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" -class ORTEncoderForWhisper(ORTEncoder): +class ORTEncoderForSpeech(ORTEncoder): """ Encoder model for ONNX Runtime inference for Whisper model. @@ -282,15 +283,24 @@ class ORTEncoderForWhisper(ORTEncoder): The ONNX Runtime inference session associated to the encoder. """ - @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(SPEECH_ENCODER_INPUTS_DOCSTRING) def forward( self, input_features: torch.FloatTensor, + attention_mask: torch.LongTensor, **kwargs, ) -> BaseModelOutput: + use_torch = isinstance(input_features, torch.Tensor) + self.parent_model.raise_on_numpy_input_io_binding(use_torch) + if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: + model_inputs = ( + [input_features, attention_mask] if "attention_mask" in self.input_names else [input_features] + ) io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( - self.session, input_features, ordered_input_names=self._ordered_input_names + self.session, + *model_inputs, + ordered_input_names=self._ordered_input_names, ) io_binding.synchronize_inputs() @@ -299,10 +309,27 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} + if use_torch: + onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() + else: + onnx_inputs = {"input_features": input_features} + if "attention_mask" in self.input_names: + onnx_inputs["attention_mask"] = attention_mask + + # TODO: Replace with a better solution + # attention_mask is exported with int64 datatype and tokenizer produces int32 input + # for speech2text model. Hence, the input is type casted for inference. + if "attention_mask" in self.input_names: + if self.session.get_inputs()[1].type == "tensor(int64)": + onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64) outputs = self.session.run(None, onnx_inputs) - last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self.device) + + last_hidden_state = outputs[self.output_names["last_hidden_state"]] + if use_torch: + last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -322,6 +349,9 @@ def forward( pixel_values: torch.FloatTensor, **kwargs, ) -> BaseModelOutput: + use_torch = isinstance(pixel_values, torch.Tensor) + self.parent_model.raise_on_numpy_input_io_binding(use_torch) + if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding: io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( self.session, pixel_values, ordered_input_names=self._ordered_input_names @@ -333,10 +363,16 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - onnx_inputs = {"pixel_values": pixel_values.cpu().detach().numpy()} + if use_torch: + onnx_inputs = {"pixel_values": pixel_values.cpu().detach().numpy()} + else: + onnx_inputs = {"pixel_values": pixel_values} outputs = self.session.run(None, onnx_inputs) - last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self.device) + + last_hidden_state = outputs[self.output_names["last_hidden_state"]] + if use_torch: + last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -934,18 +970,8 @@ class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin auto_model_class = AutoModelForSpeechSeq2Seq main_input_name = "input_features" - _MODEL_TYPE_TO_ORTENCODER = { - "whisper": ORTEncoderForWhisper, - } - def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: - if self.config.model_type not in self._MODEL_TYPE_TO_ORTENCODER: - raise KeyError( - f"{self.config.model_type} is not supported yet. " - f"Only {list(self._MODEL_TYPE_TO_ORTENCODER.keys())} are supported. " - f"If you want to support {self.config.model_type} please propose a PR or open up an issue." - ) - return self._MODEL_TYPE_TO_ORTENCODER[self.config.model_type](session, self) + return ORTEncoderForSpeech(session, self) @add_start_docstrings_to_model_forward( SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING.format("batch_size, feature_size, sequence_length") @@ -958,6 +984,7 @@ def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: def forward( self, input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, @@ -966,7 +993,7 @@ def forward( ) -> Seq2SeqLMOutput: # Encode if needed : first prediction pass if encoder_outputs is None: - encoder_outputs = self.encoder(input_features=input_features) + encoder_outputs = self.encoder(input_features=input_features, attention_mask=attention_mask) # Decode if past_key_values is None or self.decoder_with_past is None: @@ -992,6 +1019,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, + attention_mask=None, past_key_values=None, head_mask=None, decoder_head_mask=None, diff --git a/optimum/pipelines.py b/optimum/pipelines.py index d15acee477..2e24663754 100644 --- a/optimum/pipelines.py +++ b/optimum/pipelines.py @@ -17,6 +17,7 @@ from typing import Any, Dict, Optional, Union from transformers import ( + AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline, FeatureExtractionPipeline, FillMaskPipeline, @@ -49,6 +50,7 @@ if is_onnxruntime_available(): from .onnxruntime import ( + ORTModelForAudioClassification, ORTModelForCausalLM, ORTModelForFeatureExtraction, ORTModelForImageClassification, @@ -148,6 +150,12 @@ "default": "nlpconnect/vit-gpt2-image-captioning", "type": "multimodal", }, + "audio-classification": { + "impl": AudioClassificationPipeline, + "class": (ORTModelForAudioClassification,), + "default": "superb/hubert-base-superb-ks", + "type": "audio", + }, } NO_FEATURE_EXTRACTOR_TASKS = set() @@ -155,7 +163,9 @@ for task, values in SUPPORTED_TASKS.items(): if values["type"] == "text": NO_FEATURE_EXTRACTOR_TASKS.add(task) - elif values["type"] == "image": + elif values["type"] in {"image", "video"}: + NO_TOKENIZER_TASKS.add(task) + elif values["type"] in {"audio"}: NO_TOKENIZER_TASKS.add(task) elif values["type"] != "multimodal": raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}") diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index c29441eb59..3855e392f9 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -140,6 +140,13 @@ def __getattr__(self, attr_name): hidden_size="cross_attention_hidden_size", ) +SpeechToTextLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args( + decoder_num_layers="decoder_layers", + num_layers="decoder_layers", + input_features_per_channel="input_feat_per_channel", + allow_new=True, +) + class NormalizedConfigManager: """ @@ -205,6 +212,7 @@ class NormalizedConfigManager: "poolformer": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, + "speech_to_text": SpeechToTextLikeNormalizedTextConfig, "splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, "trocr": TrOCRLikeNormalizedTextConfig, diff --git a/setup.py b/setup.py index 31df82d6ab..b530baded4 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,16 @@ "datasets", ] -TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow", "sacremoses", "diffusers"] +TESTS_REQUIRE = [ + "pytest", + "requests", + "parameterized", + "pytest-xdist", + "Pillow", + "sacremoses", + "diffusers", + "torchaudio", +] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"] diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 5a41948250..d80f410162 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -33,9 +33,14 @@ from PIL import Image from transformers import ( AutoConfig, + AutoFeatureExtractor, AutoImageProcessor, AutoModel, + AutoModelForAudioClassification, + AutoModelForAudioFrameClassification, + AutoModelForAudioXVector, AutoModelForCausalLM, + AutoModelForCTC, AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, @@ -64,7 +69,11 @@ ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME, ONNX_WEIGHTS_NAME, + ORTModelForAudioClassification, + ORTModelForAudioFrameClassification, + ORTModelForAudioXVector, ORTModelForCausalLM, + ORTModelForCTC, ORTModelForCustomTasks, ORTModelForFeatureExtraction, ORTModelForImageClassification, @@ -2404,6 +2413,397 @@ def test_compare_to_io_binding(self, model_arch): gc.collect() +class ORTModelForAudioClassificationIntegrationTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "audio_spectrogram_transformer", + "data2vec_audio", + "hubert", + "sew", + "sew_d", + "unispeech", + "unispeech_sat", + "wavlm", + "wav2vec2", + "wav2vec2-conformer", + ] + + FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} + ORTMODEL_CLASS = ORTModelForAudioClassification + TASK = "audio-classification" + + def _generate_random_audio_data(self): + np.random.seed(10) + t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + # generate pure sine wave at 220 Hz + audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = ORTModelForAudioClassification.from_pretrained(MODEL_NAMES["t5"], from_transformers=True) + + self.assertIn("Unrecognized configuration class", str(context.exception)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioClassification.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.config, PretrainedConfig) + + set_seed(SEED) + transformers_model = AutoModelForAudioClassification.from_pretrained(model_id) + processor = AutoFeatureExtractor.from_pretrained(model_id) + + input_values = processor(self._generate_random_audio_data(), return_tensors="pt") + + with torch.no_grad(): + transformers_outputs = transformers_model(**input_values) + + for input_type in ["pt", "np"]: + input_values = processor(self._generate_random_audio_data(), return_tensors=input_type) + onnx_outputs = onnx_model(**input_values) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + # compare tensor outputs + self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline_ort_model(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioClassification.from_pretrained(self.onnx_model_dirs[model_arch]) + processor = AutoFeatureExtractor.from_pretrained(model_id) + pipe = pipeline("audio-classification", model=onnx_model, feature_extractor=processor, sampling_rate=220) + data = self._generate_random_audio_data() + outputs = pipe(data) + + self.assertEqual(pipe.device, onnx_model.device) + + self.assertGreaterEqual(outputs[0]["score"], 0.0) + self.assertIsInstance(outputs[0]["label"], str) + + gc.collect() + + @pytest.mark.run_in_series + def test_pipeline_model_is_none(self): + pipe = pipeline("audio-classification") + data = self._generate_random_audio_data() + outputs = pipe(data) + + # compare model output class + self.assertGreaterEqual(outputs[0]["score"], 0.0) + self.assertIsInstance(outputs[0]["label"], str) + + @parameterized.expand( + grid_parameters( + {"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider", "TensorrtExecutionProvider"]} + ) + ) + @require_torch_gpu + @pytest.mark.gpu_test + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + if provider == "TensorrtExecutionProvider" and model_arch != self.__class__.SUPPORTED_ARCHITECTURES[0]: + self.skipTest("testing a single arch for TensorrtExecutionProvider") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioClassification.from_pretrained( + self.onnx_model_dirs[model_arch], provider=provider + ) + processor = AutoFeatureExtractor.from_pretrained(model_id) + pipe = pipeline("audio-classification", model=onnx_model, feature_extractor=processor, device=0) + data = self._generate_random_audio_data() + outputs = pipe(data) + # check model device + self.assertEqual(pipe.model.device.type.lower(), "cuda") + # compare model output class + self.assertGreaterEqual(outputs[0]["score"], 0.0) + self.assertTrue(isinstance(outputs[0]["label"], str)) + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_torch_gpu + @pytest.mark.gpu_test + def test_compare_to_io_binding(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioClassification.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=False + ).to("cuda") + io_model = ORTModelForAudioClassification.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=True + ).to("cuda") + + processor = AutoFeatureExtractor.from_pretrained(model_id) + data = self._generate_random_audio_data() + + input_values = processor(data, return_tensors="pt") + onnx_outputs = onnx_model(**input_values) + io_outputs = io_model(**input_values) + + self.assertTrue("logits" in io_outputs) + self.assertIsInstance(io_outputs.logits, torch.Tensor) + + # compare tensor outputs + self.assertTrue(torch.allclose(onnx_outputs.logits, io_outputs.logits, atol=1e-4)) + + gc.collect() + + +class ORTModelForCTCIntegrationTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "data2vec_audio", + "hubert", + "sew", + "sew_d", + "unispeech", + "unispeech_sat", + "wavlm", + "wav2vec2", + "wav2vec2-conformer", + ] + + FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} + ORTMODEL_CLASS = ORTModelForCTC + TASK = "ctc" + + def _generate_random_audio_data(self): + np.random.seed(10) + t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + # generate pure sine wave at 220 Hz + audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = ORTModelForCTC.from_pretrained(MODEL_NAMES["t5"], from_transformers=True) + + self.assertIn("Unrecognized configuration class", str(context.exception)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForCTC.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.config, PretrainedConfig) + + set_seed(SEED) + transformers_model = AutoModelForCTC.from_pretrained(model_id) + processor = AutoFeatureExtractor.from_pretrained(model_id) + + input_values = processor(self._generate_random_audio_data(), return_tensors="pt") + + with torch.no_grad(): + transformers_outputs = transformers_model(**input_values) + + for input_type in ["pt", "np"]: + input_values = processor(self._generate_random_audio_data(), return_tensors=input_type) + onnx_outputs = onnx_model(**input_values) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + # compare tensor outputs + self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_torch_gpu + @pytest.mark.gpu_test + def test_compare_to_io_binding(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForCTC.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=False).to("cuda") + io_model = ORTModelForCTC.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=True).to("cuda") + + processor = AutoFeatureExtractor.from_pretrained(model_id) + data = self._generate_random_audio_data() + + input_values = processor(data, return_tensors="pt") + onnx_outputs = onnx_model(**input_values) + io_outputs = io_model(**input_values) + + self.assertTrue("logits" in io_outputs) + self.assertIsInstance(io_outputs.logits, torch.Tensor) + + # compare tensor outputs + self.assertTrue(torch.allclose(onnx_outputs.logits, io_outputs.logits, atol=1e-4)) + + gc.collect() + + +class ORTModelForAudioXVectorIntegrationTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "data2vec_audio", + "unispeech_sat", + "wavlm", + "wav2vec2", + "wav2vec2-conformer", + ] + + FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} + ORTMODEL_CLASS = ORTModelForAudioXVector + TASK = "audio-xvector" + + def _generate_random_audio_data(self): + np.random.seed(10) + t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + # generate pure sine wave at 220 Hz + audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = ORTModelForAudioXVector.from_pretrained(MODEL_NAMES["t5"], from_transformers=True) + + self.assertIn("Unrecognized configuration class", str(context.exception)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioXVector.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.config, PretrainedConfig) + + set_seed(SEED) + transformers_model = AutoModelForAudioXVector.from_pretrained(model_id) + processor = AutoFeatureExtractor.from_pretrained(model_id) + input_values = processor(self._generate_random_audio_data(), return_tensors="pt") + + with torch.no_grad(): + transformers_outputs = transformers_model(**input_values) + for input_type in ["pt", "np"]: + input_values = processor(self._generate_random_audio_data(), return_tensors=input_type) + onnx_outputs = onnx_model(**input_values) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + self.assertIsInstance(onnx_outputs.embeddings, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + # compare tensor outputs + self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) + self.assertTrue( + torch.allclose(torch.Tensor(onnx_outputs.embeddings), transformers_outputs.embeddings, atol=1e-4) + ) + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_torch_gpu + @pytest.mark.gpu_test + def test_compare_to_io_binding(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioXVector.from_pretrained( + self.onnx_model_dirs[model_arch], use_io_binding=False + ).to("cuda") + io_model = ORTModelForAudioXVector.from_pretrained(self.onnx_model_dirs[model_arch], use_io_binding=True).to( + "cuda" + ) + + processor = AutoFeatureExtractor.from_pretrained(model_id) + data = self._generate_random_audio_data() + + input_values = processor(data, return_tensors="pt") + onnx_outputs = onnx_model(**input_values) + io_outputs = io_model(**input_values) + + self.assertTrue("logits" in io_outputs) + self.assertIsInstance(io_outputs.logits, torch.Tensor) + self.assertIsInstance(io_outputs.embeddings, torch.Tensor) + + # compare tensor outputs + self.assertTrue(torch.allclose(onnx_outputs.logits, io_outputs.logits, atol=1e-4)) + self.assertTrue(torch.allclose(onnx_outputs.embeddings, io_outputs.embeddings, atol=1e-4)) + gc.collect() + + +class ORTModelForAudioFrameClassificationIntegrationTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "data2vec_audio", + "unispeech_sat", + "wavlm", + "wav2vec2", + "wav2vec2-conformer", + ] + + FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} + ORTMODEL_CLASS = ORTModelForAudioFrameClassification + TASK = "audio-frame-classification" + + def _generate_random_audio_data(self): + np.random.seed(10) + t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + # generate pure sine wave at 220 Hz + audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = ORTModelForAudioFrameClassification.from_pretrained(MODEL_NAMES["t5"], from_transformers=True) + + self.assertIn("Unrecognized configuration class", str(context.exception)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + onnx_model = ORTModelForAudioFrameClassification.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.config, PretrainedConfig) + + set_seed(SEED) + transformers_model = AutoModelForAudioFrameClassification.from_pretrained(model_id) + processor = AutoFeatureExtractor.from_pretrained(model_id) + input_values = processor(self._generate_random_audio_data(), return_tensors="pt") + + with torch.no_grad(): + transformers_outputs = transformers_model(**input_values) + for input_type in ["pt", "np"]: + input_values = processor(self._generate_random_audio_data(), return_tensors=input_type) + onnx_outputs = onnx_model(**input_values) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + # compare tensor outputs + self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) + + gc.collect() + + class ORTModelForSeq2SeqLMIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES = [ "bart", @@ -2718,7 +3118,7 @@ def test_compare_generation_to_io_binding(self, test_name: str, model_arch: str, class ORTModelForSpeechSeq2SeqIntegrationTest(ORTModelTestMixin): # TODO: speech_to_text should be tested - SUPPORTED_ARCHITECTURES = ["whisper"] + SUPPORTED_ARCHITECTURES = ["whisper", "speech_to_text"] FULL_GRID = { "model_arch": SUPPORTED_ARCHITECTURES, @@ -2733,9 +3133,10 @@ class ORTModelForSpeechSeq2SeqIntegrationTest(ORTModelTestMixin): def _generate_random_audio_data(self): np.random.seed(10) - t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False) + t = np.linspace(0, 5.0, int(5.0 * 18736), endpoint=False) # generate pure sine wave at 220 Hz audio_data = 0.5 * np.sin(2 * np.pi * 220 * t) + return audio_data def test_load_vanilla_transformers_which_is_not_supported(self): @@ -2781,19 +3182,28 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach processor = get_preprocessor(model_id) data = self._generate_random_audio_data() + features = processor.feature_extractor(data, return_tensors="pt") decoder_start_token_id = transformers_model.config.decoder_start_token_id decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} - onnx_outputs = onnx_model(**features, **decoder_inputs) - - self.assertTrue("logits" in onnx_outputs) - self.assertIsInstance(onnx_outputs.logits, torch.Tensor) with torch.no_grad(): transformers_outputs = transformers_model(**features, **decoder_inputs) - # Compare tensor outputs - self.assertTrue(torch.allclose(onnx_outputs.logits, transformers_outputs.logits, atol=1e-4)) + + for input_type in ["pt", "np"]: + features = processor.feature_extractor(data, return_tensors=input_type) + + if input_type == "np": + decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id} + + onnx_outputs = onnx_model(**features, **decoder_inputs) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + # Compare tensor outputs + self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) gc.collect() @@ -2863,6 +3273,9 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): model_with_pkv = ORTModelForSpeechSeq2Seq.from_pretrained( self.onnx_model_dirs[model_arch + "_True"], use_cache=True ) + + generation_length = self.GENERATION_LENGTH + self.GENERATION_LENGTH = 10 _ = model_with_pkv.generate(**features) # warpup with Timer() as with_pkv_timer: outputs_model_with_pkv = model_with_pkv.generate( @@ -2881,7 +3294,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH) - + self.GENERATION_LENGTH = generation_length if os.environ.get("TEST_LEVEL", 0) == "1": self.assertTrue( without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE, @@ -3244,6 +3657,10 @@ class TestBothExportersORTModel(unittest.TestCase): ["semantic-segmentation", ORTModelForSemanticSegmentationIntegrationTest], ["seq2seq-lm", ORTModelForSeq2SeqLMIntegrationTest], ["speech2seq-lm", ORTModelForSpeechSeq2SeqIntegrationTest], + ["audio-classification", ORTModelForAudioClassificationIntegrationTest], + ["audio-ctc", ORTModelForCTCIntegrationTest], + ["audio-xvector", ORTModelForAudioXVectorIntegrationTest], + ["audio-frame-classification", ORTModelForAudioFrameClassificationIntegrationTest], ] ) def test_find_untested_architectures(self, task: str, test_class): diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 99bef6061c..8ea3a78435 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -73,12 +73,12 @@ "hubert": "hf-internal-testing/tiny-random-HubertModel", "wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model", "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", - "wavlm": "hf-internal-testing/tiny-random-wavlm", + "wavlm": "hf-internal-testing/tiny-random-WavlmModel", "sew": "hf-internal-testing/tiny-random-SEWModel", "sew_d": "hf-internal-testing/tiny-random-SEWDModel", "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "unispeech": "hf-internal-testing/tiny-random-unispeech", - "unispeech_sat": "hf-internal-testing/tiny-random-unispeech-sat", + "unispeech_sat": "hf-internal-testing/tiny-random-UnispeechSatModel", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",