Move custom IOBinding to IOBindingHelper #571

Merged
merged 4 commits on Dec 12, 2022
Changes from 2 commits
36 changes: 36 additions & 0 deletions optimum/onnxruntime/io_binding/io_binding_helper.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
import traceback
from typing import TYPE_CHECKING

import numpy as np
import torch
@@ -24,6 +25,9 @@
from ..utils import is_cupy_available, is_onnxruntime_training_available


if TYPE_CHECKING:
    # Only needed for type hints; a runtime import would be circular,
    # since modeling_ort itself imports from this module.
    from ..modeling_ort import ORTModel

if is_cupy_available():
    import cupy as cp

@@ -147,3 +151,35 @@ def get_device_index(device):
        elif isinstance(device, int):
            return device
        return 0 if device.index is None else device.index

    @staticmethod
    def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding:
        """
        Returns an IOBinding object for an inference session. This method is general-purpose; if the inputs and
        outputs are known in advance, you can prepare the data buffers directly to avoid tensor transfers across
        frameworks.
        """
        name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_model.model)

        # Bind inputs and outputs to the onnxruntime session
        io_binding = ort_model.model.io_binding()

        # Bind inputs
        for input_name in ort_model.model_input_names:
            onnx_input = inputs.pop(input_name)
            onnx_input = onnx_input.contiguous()

            io_binding.bind_input(
                input_name,
                onnx_input.device.type,
                ort_model.device.index,
                name_to_np_type[input_name],
                list(onnx_input.size()),
                onnx_input.data_ptr(),
            )

        # Bind outputs
        for name in ort_model.model_output_names:
            io_binding.bind_output(name, ort_model.device.type, device_id=ort_model.device.index)

        return io_binding
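
For context, here is a minimal usage sketch of the relocated helper. The checkpoint name, input names, and shapes are hypothetical, and the import path assumes the package exposes `IOBindingHelper` from `optimum.onnxruntime.io_binding`:

```python
import torch

from optimum.onnxruntime import ORTModelForCustomTasks
from optimum.onnxruntime.io_binding import IOBindingHelper  # assumed export path

# Hypothetical checkpoint; any single-file custom-task ONNX model would do.
model = ORTModelForCustomTasks.from_pretrained("my-org/my-onnx-model")
model.to("cuda")  # forward() only takes the IOBinding path on CUDA

# Hypothetical input names/shapes; they must match the names in
# model.model_input_names, i.e. the ONNX graph's declared inputs.
inputs = {
    "input_ids": torch.ones(1, 8, dtype=torch.int64, device="cuda"),
    "attention_mask": torch.ones(1, 8, dtype=torch.int64, device="cuda"),
}
io_binding = IOBindingHelper.prepare_io_binding(model, **inputs)
```

As the docstring hints, callers that know their output shapes in advance can skip the device-default output binding and instead pass a preallocated buffer to `io_binding.bind_output(name, device_type, device_id, element_type, shape, buffer_ptr)`, avoiding an extra allocation and copy.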
35 changes: 2 additions & 33 deletions optimum/onnxruntime/modeling_ort.py
@@ -1576,7 +1576,7 @@ def forward(

@add_start_docstrings(
    """
-   Onnx Model for any custom tasks. It can be used to leverage the inference acceleration with any custom exported ONNX model.
+   ONNX Model for any custom tasks. It can be used to leverage the inference acceleration for any single-file ONNX model.
    """,
    ONNX_MODEL_START_DOCSTRING,
)
@@ -1592,37 +1592,6 @@ def __init__(self, model=None, config=None, use_io_binding=True, **kwargs):
        self.model_input_names = list(self.model_inputs.keys())
        self.model_output_names = list(self.model_outputs.keys())

    def prepare_io_binding(self, **kwargs) -> ort.IOBinding:
        """
        Returns an IOBinding object for an inference session. This method is general-purpose; if the inputs and
        outputs are known in advance, you can prepare the data buffers directly to avoid tensor transfers across
        frameworks.
        """
        name_to_np_type = TypeHelper.get_io_numpy_type_map(self.model)

        # Bind inputs and outputs to the onnxruntime session
        io_binding = self.model.io_binding()

        # Bind inputs
        for input_name in self.model_input_names:
            onnx_input = kwargs.pop(input_name)
            onnx_input = onnx_input.contiguous()

            io_binding.bind_input(
                input_name,
                onnx_input.device.type,
                self.device.index,
                name_to_np_type[input_name],
                list(onnx_input.size()),
                onnx_input.data_ptr(),
            )

        # Bind outputs
        for name in self.model_output_names:
            io_binding.bind_output(name, self.device.type, device_id=self.device.index)

        return io_binding

    @add_start_docstrings_to_model_forward(
        CUSTOM_TASKS_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
@@ -1632,7 +1601,7 @@ def prepare_io_binding(self, **kwargs) -> ort.IOBinding:
    )
    def forward(self, **kwargs):
        if self.device.type == "cuda" and self.use_io_binding:
-           io_binding = self.prepare_io_binding(**kwargs)
+           io_binding = IOBindingHelper.prepare_io_binding(self, **kwargs)

            # run inference with binding
            io_binding.synchronize_inputs()
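
The diff is truncated here, but the remainder of `forward` follows the standard ONNX Runtime IOBinding pattern. A sketch of that flow, not necessarily the PR's exact lines, where `session` is the underlying `ort.InferenceSession` (`model.model`) and `io_binding` comes from `prepare_io_binding`:

```python
# Standard onnxruntime IOBinding flow (sketch).
io_binding.synchronize_inputs()          # make sure bound input buffers are ready
session.run_with_iobinding(io_binding)   # run inference against the bound buffers
io_binding.synchronize_outputs()         # wait for the outputs before reading them

# Outputs are OrtValues on the bound device; copy back to host to consume them.
outputs = io_binding.copy_outputs_to_cpu()
```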