Add ORTModelXXX for audio (#774)
* add CTC and sequence classification models

* add xvector

* add audio frame classification

* update docstring documentation

* updated docs

* add audio classification test

* update models

* update models

* add audio classification test

* add audio classification test

* add audio classification test

* fixed tests

* fix tests

* update model

* rebase

* update test

* fix test

* fixed tests

* update docs

* fix code docstring

* update docs

* add numpy tests

* add error for iobinding

* update docs

* update docs

---------

Co-authored-by: Mohit Sharma <mohit@huggingface.co>
mht-sharma authored Feb 24, 2023
1 parent 0d948ae commit 0c72a56
Showing 11 changed files with 1,087 additions and 96 deletions.
83 changes: 62 additions & 21 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
@@ -12,57 +12,98 @@ specific language governing permissions and limitations under the License.

# Models

## ORTModel
## Generic model classes

The following ORT classes are available for instantiating a base model class without a specific head.

### ORTModel

[[autodoc]] onnxruntime.ORTModel

## ORTModelForCausalLM
## Natural Language Processing

The following ORT classes are available for the following natural language processing tasks.

### ORTModelForCausalLM

[[autodoc]] onnxruntime.ORTModelForCausalLM

## ORTModelForCustomTasks
### ORTModelForMaskedLM

[[autodoc]] onnxruntime.ORTModelForCustomTasks
[[autodoc]] onnxruntime.ORTModelForMaskedLM

## ORTModelForFeatureExtraction
### ORTModelForSeq2SeqLM

[[autodoc]] onnxruntime.ORTModelForFeatureExtraction
[[autodoc]] onnxruntime.ORTModelForSeq2SeqLM

## ORTModelForImageClassification
### ORTModelForSequenceClassification

[[autodoc]] onnxruntime.ORTModelForImageClassification
[[autodoc]] onnxruntime.ORTModelForSequenceClassification

## ORTModelForMaskedLM
[[autodoc]] onnxruntime.ORTModelForMaskedLM
### ORTModelForTokenClassification

## ORTModelForMultipleChoice
[[autodoc]] onnxruntime.ORTModelForTokenClassification

### ORTModelForMultipleChoice

[[autodoc]] onnxruntime.ORTModelForMultipleChoice

## ORTModelForQuestionAnswering
## Computer vision

The following ORT classes are available for the following computer vision tasks.

### ORTModelForQuestionAnswering

[[autodoc]] onnxruntime.ORTModelForQuestionAnswering

## ORTModelForSemanticSegmentation
### ORTModelForImageClassification

[[autodoc]] onnxruntime.ORTModelForImageClassification

### ORTModelForSemanticSegmentation

[[autodoc]] onnxruntime.ORTModelForSemanticSegmentation

## ORTModelForSeq2SeqLM
## Audio

[[autodoc]] onnxruntime.ORTModelForSeq2SeqLM
The following ORT classes are available for the following audio tasks.

## ORTModelForSequenceClassification
### ORTModelForAudioClassification

[[autodoc]] onnxruntime.ORTModelForSequenceClassification
[[autodoc]] onnxruntime.ORTModelForAudioClassification

### ORTModelForAudioFrameClassification

[[autodoc]] onnxruntime.ORTModelForAudioFrameClassification

### ORTModelForCTC

## ORTModelForSpeechSeq2Seq
[[autodoc]] onnxruntime.ORTModelForCTC

### ORTModelForSpeechSeq2Seq

[[autodoc]] onnxruntime.ORTModelForSpeechSeq2Seq

## ORTModelForTokenClassification
### ORTModelForAudioXVector

[[autodoc]] onnxruntime.ORTModelForTokenClassification
[[autodoc]] onnxruntime.ORTModelForAudioXVector

## Multimodal

The following ORT classes are available for the following multimodal tasks.

## ORTModelForVision2Seq
### ORTModelForVision2Seq

[[autodoc]] onnxruntime.ORTModelForVision2Seq

## Custom Tasks

The following ORT classes are available for the following custom tasks.

### ORTModelForCustomTasks

[[autodoc]] onnxruntime.ORTModelForCustomTasks

### ORTModelForFeatureExtraction

[[autodoc]] onnxruntime.ORTModelForFeatureExtraction
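
For context, a minimal usage sketch of the new audio classes documented above. The checkpoint id and the `from_transformers=True` export flag are illustrative assumptions, not something referenced by this diff:

```python
# Minimal sketch: load a transformers audio checkpoint through ONNX Runtime.
from optimum.onnxruntime import ORTModelForAudioClassification
from transformers import AutoFeatureExtractor, pipeline

model_id = "superb/hubert-base-superb-ks"  # illustrative checkpoint

# Export the PyTorch weights to ONNX on the fly and wrap them in an ORT session.
model = ORTModelForAudioClassification.from_pretrained(model_id, from_transformers=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# The ORT model drops into the regular transformers pipeline API.
classifier = pipeline("audio-classification", model=model, feature_extractor=feature_extractor)
print(classifier("sample.wav"))
```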
30 changes: 24 additions & 6 deletions optimum/exporters/onnx/model_configs.py
@@ -958,8 +958,10 @@ class Speech2TextOnnxConfig(AudioToTextOnnxConfig):
allow_new=True,
)
DUMMY_INPUT_GENERATOR_CLASSES = (
Speech2TextDummyAudioInputGenerator,
) + AudioToTextOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]
(Speech2TextDummyAudioInputGenerator,)
+ AudioToTextOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:]
+ (DummyTextInputGenerator,)
)
ATOL_FOR_VALIDATION = 1e-4

def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
@@ -978,11 +980,27 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs
common_inputs = {}

if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}
common_inputs["attention_mask"] = {0: "batch_size", 1: "encoder_sequence_length"}

if self._behavior is not ConfigBehavior.ENCODER:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"][
1
] = f"{common_inputs['encoder_outputs'][1]} / {( 2 * self._config.num_conv_layers)}"
common_inputs["encoder_outputs"] = {
0: "batch_size",
1: f"encoder_sequence_length / {( 2 * self._config.num_conv_layers)}",
}

return common_inputs

@property
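To make the reworked `inputs` property concrete, here is the dynamic-axes mapping it would return in each behavior, written out by hand. This is a sketch derived from the hunk above; the decoder case assumes a Speech2Text config with `num_conv_layers = 2`:

```python
# Encoder export: raw audio features plus their attention mask.
encoder_inputs = {
    "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
    "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
}

# Decoder export without past key values: the encoder output length is tied
# to the conv-subsampled audio length, here 2 * num_conv_layers = 4.
decoder_inputs = {
    "decoder_input_ids": {0: "batch_size", 1: "decoder_sequence_length"},
    "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length / 4"},
}
```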
8 changes: 8 additions & 0 deletions optimum/onnxruntime/__init__.py
@@ -29,7 +29,11 @@
],
"modeling_ort": [
"ORTModel",
"ORTModelForAudioClassification",
"ORTModelForAudioFrameClassification",
"ORTModelForAudioXVector",
"ORTModelForCustomTasks",
"ORTModelForCTC",
"ORTModelForFeatureExtraction",
"ORTModelForImageClassification",
"ORTModelForMaskedLM",
@@ -64,6 +68,10 @@
from .modeling_decoder import ORTModelForCausalLM
from .modeling_ort import (
ORTModel,
ORTModelForAudioClassification,
ORTModelForAudioFrameClassification,
ORTModelForAudioXVector,
ORTModelForCTC,
ORTModelForCustomTasks,
ORTModelForFeatureExtraction,
ORTModelForImageClassification,
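With these exports in place, CTC-based speech recognition becomes available through ONNX Runtime. A hedged sketch, where the wav2vec2 checkpoint is an illustrative choice and the waveform is a dummy second of silence:

```python
import numpy as np
import torch
from optimum.onnxruntime import ORTModelForCTC
from transformers import AutoProcessor

model_id = "facebook/wav2vec2-base-960h"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = ORTModelForCTC.from_pretrained(model_id, from_transformers=True)

# Dummy 16 kHz waveform standing in for real audio.
waveform = np.zeros(16_000, dtype=np.float32)
inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt")

# Greedy CTC decoding of the ONNX Runtime logits.
logits = model(**inputs).logits
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids))
```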
132 changes: 97 additions & 35 deletions optimum/onnxruntime/base.py
@@ -74,6 +74,9 @@ def forward(
attention_mask: torch.LongTensor,
**kwargs,
) -> BaseModelOutput:
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

if self.device.type == "cuda" and self.parent_model.use_io_binding:
model_inputs = [input_ids]
if "attention_mask" in self.input_names:
@@ -90,15 +93,25 @@

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}
if use_torch:
onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
else:
onnx_inputs = {"input_ids": input_ids}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask

# Run inference
outputs = self.session.run(None, onnx_inputs)
last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self.device)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)

return BaseModelOutput(last_hidden_state=last_hidden_state)
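
`raise_on_numpy_input_io_binding` itself is defined elsewhere in this PR (see the "add error for iobinding" commit above). Since IO binding needs torch tensors to bind device buffers, it is presumably a guard along these lines — a sketch, not the committed implementation:

```python
def raise_on_numpy_input_io_binding(self, use_torch: bool):
    """Sketch of the guard: refuse numpy inputs when IO binding is active,
    because binding input/output buffers requires torch tensors on device."""
    if self.use_io_binding and not use_torch:
        raise ValueError(
            "IO binding cannot be used with numpy inputs; pass torch tensors "
            "or disable IO binding (use_io_binding=False)."
        )
```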

@@ -254,6 +267,8 @@ def forward(
use_cache_branch: None = None,
) -> CausalLMOutputWithCrossAttentions:
# adding use_cache_branch in the signature here is just a hack for IO Binding
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

# Flatten the past_key_values
if past_key_values is not None:
@@ -310,21 +325,38 @@
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
"attention_mask": attention_mask.cpu().detach().numpy(),
}
if use_torch:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
"attention_mask": attention_mask.cpu().detach().numpy(),
}

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy()

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
onnx_inputs["labels"] = labels.cpu().detach().numpy()
else:
onnx_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy()
if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()
if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value

if "labels" in self.input_names:
onnx_inputs["labels"] = labels.cpu().detach().numpy()
if "labels" in self.input_names:
onnx_inputs["labels"] = labels

# Run inference
outputs = self.session.run(None, onnx_inputs)
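
The torch and numpy branches above assemble the same `onnx_inputs` dict twice, differing only in whether each tensor is converted. A hypothetical conversion helper (not part of this commit) would collapse the duplication:

```python
import torch

def to_onnx_value(value):
    """Hypothetical helper: torch tensors become numpy arrays, numpy passes through."""
    if isinstance(value, torch.Tensor):
        return value.cpu().detach().numpy()
    return value

# Both branches would then reduce to a single dict comprehension, e.g.:
# onnx_inputs = {name: to_onnx_value(v) for name, v in candidate_inputs.items() if v is not None}
```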
@@ -382,6 +414,8 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
) -> Seq2SeqLMOutput:
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)
# Flatten the past_key_values
if past_key_values is not None:
past_key_values = tuple(
@@ -449,26 +483,49 @@ def filter_out_output(output_name):
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
}
# from pdb import set_trace; set_trace()
if use_torch:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
}

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy()

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy()

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels.cpu().detach().numpy()
else:
onnx_inputs = {
"input_ids": input_ids,
}

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy()
# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy()
# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()
if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels.cpu().detach().numpy()
if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels

# Run inference
outputs = self.session.run(None, onnx_inputs)
@@ -484,10 +541,15 @@ def filter_out_output(output_name):
# cross-attention per decoder layer
num_pkv = 4
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)

logits = outputs[self.output_names["logits"]]
if use_torch:
logits = torch.from_numpy(logits).to(self.device)

loss = None
if "loss" in self.output_names:
loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)
loss = outputs[self.output_names["loss"]]
if use_torch:
loss = torch.from_numpy(loss).to(self.device)

return Seq2SeqLMOutput(loss=loss, logits=logits, past_key_values=past_key_values)
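
Taken together, these base.py changes let the seq2seq classes run on numpy inputs whenever IO binding is disabled, returning numpy outputs without a torch round-trip. A hedged end-to-end sketch; the checkpoint is illustrative and the `from_transformers=True` flag is an assumption:

```python
import numpy as np
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer

model_id = "t5-small"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForSeq2SeqLM.from_pretrained(
    model_id, from_transformers=True, use_io_binding=False
)

inputs = tokenizer("translate English to German: hello", return_tensors="np")
decoder_input_ids = np.array([[model.config.decoder_start_token_id]])

# numpy in, numpy out; with use_io_binding=True the same call would raise.
logits = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    decoder_input_ids=decoder_input_ids,
).logits
print(type(logits))  # <class 'numpy.ndarray'>
```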
