Support ONNX export on torch.float16 type #749

Merged · 18 commits · Feb 27, 2023
Changes from 2 commits
7 changes: 7 additions & 0 deletions optimum/commands/export/onnx.py
@@ -52,6 +52,13 @@ def parse_args_onnx(parser):
default="cpu",
help='The device to use to do the export. Defaults to "cpu".',
)
optional_group.add_argument(
"--dtype",
type=str,
default=None,
choices=["float32", "float16", None],
help="Experimental option: the dtype of the weights to use during the export. If None, the default dtype will be used. PyTorch-only.",
)
optional_group.add_argument(
"--opset",
type=int,
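The new flag is exercised end to end by the test added at the bottom of this diff; a hypothetical invocation (the output directory name is illustrative):

    python3 -m optimum.exporters.onnx --model hf-internal-testing/tiny-random-t5 --device cuda --dtype float16 --task seq2seq-lm-with-past t5_fp16_onnx/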
19 changes: 18 additions & 1 deletion optimum/exporters/onnx/__main__.py
@@ -30,6 +30,7 @@
    get_decoder_models_for_export,
    get_encoder_decoder_models_for_export,
    get_stable_diffusion_models_for_export,
    str_dtype_to_torch_dtype,
)


@@ -59,13 +60,27 @@ def main():
f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

if args.framework == "tf" and args.dtype is not None:
raise ValueError("The --dtype option is supported only for PyTorch.")

if args.dtype is not None and args.device == "cpu":
raise ValueError(
"The --dtype option is supported on when exporting on GPU. Please pass the option --device cuda."
)

# get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = getattr(args, input_name)

    torch_dtype = str_dtype_to_torch_dtype[args.dtype]
    model = TasksManager.get_model_from_task(
-        task, args.model, framework=args.framework, cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code
+        task,
+        args.model,
+        framework=args.framework,
+        cache_dir=args.cache_dir,
+        trust_remote_code=args.trust_remote_code,
+        torch_dtype=torch_dtype,
    )

if task != "stable-diffusion":
Expand Down Expand Up @@ -143,6 +158,7 @@ def main():
            output_names=output_names,
            input_shapes=input_shapes,
            device=args.device,
            dtype=torch_dtype,
        )
    else:
        onnx_inputs, onnx_outputs = export(
@@ -152,6 +168,7 @@
            opset=args.opset,
            input_shapes=input_shapes,
            device=args.device,
            dtype=torch_dtype,
        )

    try:
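Taken together, the changes above thread the dtype through argument parsing, model loading, and export. A simplified sketch of the flow (condensed from this diff; the variable wiring is abbreviated, not the verbatim source):

torch_dtype = str_dtype_to_torch_dtype[args.dtype]  # "float16" -> torch.float16, None -> None
model = TasksManager.get_model_from_task(           # weights loaded directly in the target dtype
    task, args.model, framework=args.framework, torch_dtype=torch_dtype
)
onnx_inputs, onnx_outputs = export(                 # dummy inputs remapped to the same dtype
    model=model, config=onnx_config, output=output,
    opset=args.opset, device=args.device, dtype=torch_dtype,
)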
46 changes: 37 additions & 9 deletions optimum/exporters/onnx/convert.py
@@ -18,7 +18,7 @@
from inspect import signature
from itertools import chain
from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from transformers.utils import is_tf_available, is_torch_available
@@ -27,11 +27,14 @@

from ...onnx.utils import _get_onnx_external_data_tensors, check_model_uses_external_data
from ...utils import TORCH_MINIMUM_VERSION, is_diffusers_available, is_torch_onnx_support_available, logging
-from ..error_utils import AtolError, NumberOfInputsMatchError, NumberOfOutputsMatchError, OutputMatchError, ShapeError
+from ..error_utils import AtolError, OutputMatchError, ShapeError
from .base import OnnxConfig
-from .utils import recursive_to_device
+from .utils import recursive_to_device, recursive_to_dtype


if TYPE_CHECKING:
    import torch

if is_torch_available():
    import torch.nn as nn
    from transformers.modeling_utils import PreTrainedModel
@@ -81,6 +84,7 @@ def validate_models_outputs(
    output_names: Optional[List[str]] = None,
    input_shapes: Optional[Dict] = None,
    device: str = "cpu",
    dtype: Optional["torch.dtype"] = None,
):
    """
    Validates the export of several models, by checking that the outputs from both the reference and the exported model match.
@@ -102,6 +106,8 @@
            If specified, allows to use specific shapes to validate the ONNX model on.
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX models will be validated. Either `cpu` or `cuda`. Validation on a CUDA device is supported only for PyTorch.
        dtype (`Optional[torch.dtype]`, defaults to `None`):
            Data type of the inputs to perform validation on. Validation on float16 is supported only for PyTorch.

    Raises:
        ValueError: If the outputs shapes or values do not match between the reference and the exported model.
@@ -131,6 +137,7 @@
            atol=atol,
            input_shapes=input_shapes,
            device=device,
            dtype=dtype,
        )


@@ -142,6 +149,7 @@ def validate_model_outputs(
    atol: Optional[float] = None,
    input_shapes: Optional[Dict] = None,
    device: str = "cpu",
    dtype: Optional["torch.dtype"] = None,
):
    """
    Validates the export by checking that the outputs from both the reference and the exported model match.
@@ -161,6 +169,8 @@
            If specified, allows to use specific shapes to validate the ONNX model on.
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX model will be validated. Either `cpu` or `cuda`. Validation on a CUDA device is supported only for PyTorch.
        dtype (`Optional[torch.dtype]`, defaults to `None`):
            Data type of the inputs to perform validation on. Validation on float16 is supported only for PyTorch.

    Raises:
        ValueError: If the outputs shapes or values do not match between the reference and the exported model.
@@ -196,7 +206,8 @@
        reference_model.to(device)

    for key, value in reference_model_inputs.items():
-        reference_model_inputs[key] = recursive_to_device(value=value, device=device)
+        reference_model_inputs[key] = recursive_to_dtype(value=value, dtype=dtype)
+        reference_model_inputs[key] = recursive_to_device(value=reference_model_inputs[key], device=device)

    ref_outputs = reference_model(**reference_model_inputs)
    ref_outputs_dict = {}
@@ -290,6 +301,7 @@ def export_pytorch(
    output: Path,
    device: str = "cpu",
    input_shapes: Optional[Dict] = None,
    dtype: Optional["torch.dtype"] = None,
) -> Tuple[List[str], List[str]]:
    """
    Exports a PyTorch model to an ONNX Intermediate Representation.
@@ -306,8 +318,10 @@
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.
-        input_shapes (`optional[Dict]`, defaults to `None`):
+        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
        dtype (`Optional[torch.dtype]`, defaults to `None`):
            Data type to remap the model inputs to.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -337,11 +351,18 @@
        # Check that inputs match, and order them properly
        dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
        device = torch.device(device)

        def remap(value):
            if isinstance(value, torch.Tensor):
                value = value.to(device)
            if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
                value = value.to(dtype=dtype)

            return value

        if device.type == "cuda" and torch.cuda.is_available():
            model.to(device)
-            dummy_inputs = tree_map(
-                lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
-            )
+            dummy_inputs = tree_map(remap, dummy_inputs)
        check_dummy_inputs_are_allowed(model, dummy_inputs)
        inputs = config.ordered_inputs(model)
        input_names = list(inputs.keys())
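The `remap` closure above casts a tensor to the target dtype only when it is `float32`, so integer inputs such as `input_ids` keep their dtype while float inputs follow the model weights. A standalone sketch of that behavior (the closure is copied out of `export_pytorch` for illustration; requires a CUDA device, as the float16 export path does):

import torch

device = torch.device("cuda")
dtype = torch.float16

def remap(value):  # standalone copy of the closure defined in export_pytorch above
    if isinstance(value, torch.Tensor):
        value = value.to(device)
    if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
        value = value.to(dtype=dtype)
    return value

input_ids = torch.tensor([[1, 2, 3]])                   # int64: moved to the device, dtype unchanged
attention_mask = torch.ones(1, 3, dtype=torch.float32)  # float32: moved and cast down

assert remap(input_ids).dtype == torch.int64
assert remap(attention_mask).dtype == torch.float16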
@@ -482,6 +503,7 @@ def export_models(
    output_names: Optional[List[str]] = None,
    device: str = "cpu",
    input_shapes: Optional[Dict] = None,
    dtype: Optional["torch.dtype"] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Exports a Pytorch or TensorFlow encoder decoder model to an ONNX Intermediate Representation.
@@ -503,6 +525,8 @@
            export on CUDA devices.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
        dtype (`Optional[torch.dtype]`, defaults to `None`):
            Data type to remap the model inputs to. PyTorch-only. Only `float16` is supported.

    Returns:
        `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named
        inputs from the ONNX configuration.
@@ -529,6 +553,7 @@
                opset=opset,
                device=device,
                input_shapes=input_shapes,
                dtype=dtype,
            )
        )

@@ -543,6 +568,7 @@ def export(
    opset: Optional[int] = None,
    device: str = "cpu",
    input_shapes: Optional[Dict] = None,
    dtype: Optional["torch.dtype"] = None,
) -> Tuple[List[str], List[str]]:
    """
    Exports a Pytorch or TensorFlow model to an ONNX Intermediate Representation.
@@ -561,6 +587,8 @@
            export on CUDA devices.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
        dtype (`Optional[torch.dtype]`, defaults to `None`):
            Data type to remap the model inputs to. PyTorch-only. Only `float16` is supported.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -593,7 +621,7 @@ def export(
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION},"
f" got: {torch.__version__}"
)
return export_pytorch(model, config, opset, output, device=device, input_shapes=input_shapes)
return export_pytorch(model, config, opset, output, device=device, input_shapes=input_shapes, dtype=dtype)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
if device == "cuda":
29 changes: 27 additions & 2 deletions optimum/exporters/onnx/utils.py
@@ -15,13 +15,13 @@
"""Utility functions."""

import copy
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import packaging
import torch
from transformers.utils import is_tf_available, is_torch_available

-from ...utils import ORT_QUANTIZE_MINIMUM_VERSION, TORCH_MINIMUM_VERSION, is_diffusers_available
+from ...utils import ORT_QUANTIZE_MINIMUM_VERSION, is_diffusers_available
from ..tasks import TasksManager


@@ -199,3 +199,28 @@ def recursive_to_device(value: Union[Tuple, List, "torch.Tensor"], device: str):
        value = value.to(device)

    return value


def recursive_to_dtype(value: Union[Tuple, List, "torch.Tensor"], dtype: Optional[torch.dtype]):
    if dtype is None:
        return value

    if isinstance(value, tuple):
        value = list(value)
        for i, val in enumerate(value):
            value[i] = recursive_to_dtype(val, dtype)
        value = tuple(value)
    elif isinstance(value, list):
        for i, val in enumerate(value):
            value[i] = recursive_to_dtype(val, dtype)
    elif isinstance(value, torch.Tensor):
        value = value.to(dtype=dtype)

    return value
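A quick usage sketch of `recursive_to_dtype` (with the helper above in scope): it walks tuples and lists, preserves the nesting, and, unlike the `remap` closure in `convert.py`, casts every tensor it finds regardless of its original dtype:

import torch

nested = (torch.ones(2, dtype=torch.float32), [torch.zeros(3, dtype=torch.float64)])
remapped = recursive_to_dtype(nested, torch.float16)
assert remapped[0].dtype == torch.float16          # tuple element cast
assert remapped[1][0].dtype == torch.float16       # nested list element cast, structure preserved
assert recursive_to_dtype(nested, None) is nested  # dtype=None is a no-op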


str_dtype_to_torch_dtype = {
    None: None,
    "float16": torch.float16,
    "float32": torch.float32,
}
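The `None: None` entry makes the lookup in `__main__.py` a no-op when `--dtype` is omitted, so the default export path is unchanged:

import torch

assert str_dtype_to_torch_dtype["float16"] is torch.float16
assert str_dtype_to_torch_dtype[None] is None  # no --dtype flag: keep the model's default dtype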
6 changes: 6 additions & 0 deletions optimum/exporters/tasks.py
@@ -27,6 +27,7 @@


if TYPE_CHECKING:
    import torch
    from transformers import PreTrainedModel, TFPreTrainedModel

    from .base import ExportConfig
@@ -976,6 +977,7 @@ def get_model_from_task(
        revision: Optional[str] = None,
        framework: Optional[str] = None,
        cache_dir: Optional[str] = None,
        torch_dtype: Optional["torch.dtype"] = None,
        **model_kwargs
    ) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
        """
@@ -997,6 +999,8 @@
                none be provided.
            cache_dir (`Optional[str]`, *optional*):
                Path to a directory in which the downloaded pretrained model weights have been cached if the standard cache should not be used.
            torch_dtype (`Optional[torch.dtype]`, defaults to `None`):
                Data type to load the model on. PyTorch-only argument.
            model_kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments to pass to the model `.from_pretrained()` method.

@@ -1010,6 +1014,8 @@
        model_class = TasksManager.get_model_class_for_task(task, framework)
        kwargs = {"subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, **model_kwargs}
        try:
            if framework == "pt":
                kwargs["torch_dtype"] = torch_dtype
            model = model_class.from_pretrained(model_name_or_path, **kwargs)
        except OSError:
            if framework == "pt":
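Forwarding `torch_dtype` into `from_pretrained` relies on the standard `transformers` loading argument of the same name; an equivalent standalone call (model name illustrative):

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.float16)
print(next(model.parameters()).dtype)  # torch.float16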
11 changes: 10 additions & 1 deletion tests/exporters/test_exporters_onnx_cli.py
@@ -18,7 +18,7 @@
from typing import Dict, Optional

from transformers import is_torch_available
-from transformers.testing_utils import require_torch, require_vision
+from transformers.testing_utils import require_torch, require_torch_gpu, require_vision

from optimum.onnxruntime import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
from parameterized import parameterized
@@ -154,3 +154,12 @@ def test_trust_remote_code(self):
                shell=True,
                check=True,
            )

    @require_torch_gpu
    def test_export_on_float16(self):
        with TemporaryDirectory() as tmpdirname:
            _ = subprocess.run(
                f"python3 -m optimum.exporters.onnx --model hf-internal-testing/tiny-random-t5 --device cuda --dtype float16 --task seq2seq-lm-with-past {tmpdirname}",
                shell=True,
                check=True,
            )
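To run only the new test locally (assumes a CUDA-capable GPU and the repository's test dependencies; standard pytest name selection):

    pytest tests/exporters/test_exporters_onnx_cli.py -k test_export_on_float16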