diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py
index 1911b1f879..a705d24a1e 100644
--- a/optimum/onnxruntime/io_binding/io_binding_helper.py
+++ b/optimum/onnxruntime/io_binding/io_binding_helper.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import logging
 import traceback
+from typing import TYPE_CHECKING
 
 import numpy as np
 import torch
@@ -24,6 +25,9 @@
 
 from ..utils import is_cupy_available, is_onnxruntime_training_available
 
+if TYPE_CHECKING:
+    from ..modeling_ort import ORTModel
+
 if is_cupy_available():
     import cupy as cp
 
@@ -147,3 +151,39 @@ def get_device_index(device):
         elif isinstance(device, int):
             return device
         return 0 if device.index is None else device.index
+
+    @staticmethod
+    def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding:
+        """
+        Returns an IOBinding object for an inference session. This method is general-purpose; if the inputs and
+        outputs are known in advance, you can prepare data buffers directly to avoid tensor transfers across frameworks.
+        """
+        if not all(input_name in inputs.keys() for input_name in ort_model.model_input_names):
+            raise ValueError(
+                f"The ONNX model takes {ort_model.model_input_names} as inputs, but only {inputs.keys()} are given."
+            )
+
+        name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_model.model)
+
+        # Bind inputs and outputs to the onnxruntime session
+        io_binding = ort_model.model.io_binding()
+
+        # Bind inputs
+        for input_name in ort_model.model_input_names:
+            onnx_input = inputs.pop(input_name)
+            onnx_input = onnx_input.contiguous()
+
+            io_binding.bind_input(
+                input_name,
+                onnx_input.device.type,
+                ort_model.device.index,
+                name_to_np_type[input_name],
+                list(onnx_input.size()),
+                onnx_input.data_ptr(),
+            )
+
+        # Bind outputs
+        for name in ort_model.model_output_names:
+            io_binding.bind_output(name, ort_model.device.type, device_id=ort_model.device.index)
+
+        return io_binding
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index f696346d1d..84ce14bbd4 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -1576,13 +1576,13 @@ def forward(
 
 @add_start_docstrings(
     """
-    Onnx Model for any custom tasks. It can be used to leverage the inference acceleration with any custom exported ONNX model.
+    ONNX model for any custom task. It can be used to leverage inference acceleration for any single-file ONNX model.
     """,
     ONNX_MODEL_START_DOCSTRING,
 )
 class ORTModelForCustomTasks(ORTModel):
     """
-    Onnx Model for any custom tasks using encoder or decoder-only models.
+    Model for any custom task, as long as the ONNX model is stored in a single file.
     """
 
     def __init__(self, model=None, config=None, use_io_binding=True, **kwargs):
@@ -1592,37 +1592,6 @@ def __init__(self, model=None, config=None, use_io_binding=True, **kwargs):
         self.model_input_names = list(self.model_inputs.keys())
         self.model_output_names = list(self.model_outputs.keys())
 
-    def prepare_io_binding(self, **kwargs) -> ort.IOBinding:
-        """
-        Returns IOBinding object for an inference session. This method is created for general purpose, if the inputs and outputs
-        are determined, you can prepare data buffers directly to avoid tensor transfers across frameworks.
-        """
-
-        name_to_np_type = TypeHelper.get_io_numpy_type_map(self.model)
-
-        # Bind inputs and outputs to onnxruntime session
-        io_binding = self.model.io_binding()
-
-        # Bind inputs
-        for input_name in self.model_input_names:
-            onnx_input = kwargs.pop(input_name)
-            onnx_input = onnx_input.contiguous()
-
-            io_binding.bind_input(
-                input_name,
-                onnx_input.device.type,
-                self.device.index,
-                name_to_np_type[input_name],
-                list(onnx_input.size()),
-                onnx_input.data_ptr(),
-            )
-
-        # Bind outputs
-        for name in self.model_output_names:
-            io_binding.bind_output(name, self.device.type, device_id=self.device.index)
-
-        return io_binding
-
     @add_start_docstrings_to_model_forward(
         CUSTOM_TASKS_EXAMPLE.format(
             processor_class=_TOKENIZER_FOR_DOC,
@@ -1632,7 +1601,7 @@ def prepare_io_binding(self, **kwargs) -> ort.IOBinding:
     )
     def forward(self, **kwargs):
         if self.device.type == "cuda" and self.use_io_binding:
-            io_binding = self.prepare_io_binding(**kwargs)
+            io_binding = IOBindingHelper.prepare_io_binding(self, **kwargs)
 
             # run inference with binding
             io_binding.synchronize_inputs()