Support ONNX export on torch.float16 type #749

Merged (18 commits) on Feb 27, 2023
5 changes: 5 additions & 0 deletions optimum/commands/export/onnx.py
@@ -54,6 +54,11 @@ def parse_args_onnx(parser):
default="cpu",
help='The device to use to do the export. Defaults to "cpu".',
)
optional_group.add_argument(
"--fp16",
action="store_true",
help="Experimental option: use half precision during the export. PyTorch-only, requires `--device cuda`.",
Contributor
Experimental because it doesn't work with all models?

Contributor Author
I'd say experimental because I haven't thoroughly tested it with ONNX Runtime + CUDAExecutionProvider / TensorrtExecutionProvider, nor with native TensorRT (though the validation itself calls InferenceSession with the CUDA EP, which is a good sign it's fine). The export itself, however, is thoroughly tested.
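For reference, a minimal sketch of that kind of check, assuming `onnxruntime-gpu` is installed; the model path and input names are placeholders rather than the actual validation code in this PR:

```python
import numpy as np
from onnxruntime import InferenceSession

# Placeholder path to an fp16 export; the real validation builds the session from the export output.
session = InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

# Integer inputs stay int64; only floating-point inputs would be fed as float16 for an fp16 export.
dummy_inputs = {
    "input_ids": np.ones((1, 8), dtype=np.int64),       # placeholder input names
    "attention_mask": np.ones((1, 8), dtype=np.int64),
}
outputs = session.run(None, dummy_inputs)
print([output.dtype for output in outputs])
```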

)
optional_group.add_argument(
"--opset",
type=int,
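For illustration, a hedged sketch of how the new flag is meant to be invoked, in the same subprocess style as the test suite at the end of this diff; the model name and output directory are placeholders and a CUDA-capable environment is assumed:

```python
import subprocess

# Placeholder model and output directory; --fp16 is PyTorch-only and requires --device cuda.
command = (
    "python3 -m optimum.exporters.onnx "
    "--model hf-internal-testing/tiny-random-bert "
    "--fp16 --device cuda "
    "bert_onnx_fp16/"
)
subprocess.run(command, shell=True, check=True)
```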
23 changes: 22 additions & 1 deletion optimum/exporters/onnx/__main__.py
@@ -17,6 +17,7 @@
from argparse import ArgumentParser

from transformers import AutoTokenizer
from transformers.utils import is_torch_available

from ...commands.export.onnx import parse_args_onnx
from ...onnxruntime import AutoOptimizationConfig, ORTOptimizer
@@ -33,6 +34,10 @@
)


if is_torch_available():
import torch


logger = logging.get_logger()
logger.setLevel(logging.INFO)

@@ -64,13 +69,27 @@ def main():
f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

if args.fp16 is True and (args.framework == "tf" or not is_torch_available()):
raise ValueError("The --fp16 option is supported only for PyTorch.")

if args.fp16 is True and args.device == "cpu":
raise ValueError(
"The --fp16 option is supported only when exporting on GPU. Please pass the option `--device cuda`."
)

# get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = getattr(args, input_name)

torch_dtype = None if args.fp16 is False else torch.float16
model = TasksManager.get_model_from_task(
task, args.model, framework=args.framework, cache_dir=args.cache_dir, trust_remote_code=args.trust_remote_code
task,
args.model,
framework=args.framework,
cache_dir=args.cache_dir,
trust_remote_code=args.trust_remote_code,
torch_dtype=torch_dtype,
)

if task.endswith("-with-past") and args.monolith is True:
@@ -173,6 +192,7 @@ def main():
output_names=onnx_files_subpaths,
input_shapes=input_shapes,
device=args.device,
dtype="fp16" if args.fp16 is True else None,
)

if args.optimize == "O4" and args.device != "cuda":
@@ -212,6 +232,7 @@ def main():
onnx_files_subpaths=onnx_files_subpaths,
input_shapes=input_shapes,
device=args.device,
dtype=torch_dtype,
)
logger.info(f"The ONNX export succeeded and the exported model was saved at: {args.output.as_posix()}")
except ShapeError as e:
8 changes: 7 additions & 1 deletion optimum/exporters/onnx/base.py
@@ -26,6 +26,7 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import onnx
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from transformers.utils import is_torch_available
@@ -287,7 +288,9 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
return copy.deepcopy(common_outputs)

def fix_dynamic_axes(self, model_path: "Path", device: str = "cpu", input_shapes: Dict = None):
def fix_dynamic_axes(
self, model_path: "Path", device: str = "cpu", dtype: Optional[str] = None, input_shapes: Optional[Dict] = None
):
"""
Fixes potential issues with dynamic axes.

@@ -332,6 +335,9 @@ def fix_dynamic_axes(self, model_path: "Path", device: str = "cpu", input_shapes
onnx_inputs.update({tensor_name: tensor for tensor_name, tensor in value.items()})
else:
onnx_inputs[name] = value
for name, value in onnx_inputs.items():
if value.dtype == np.float32 and dtype == "fp16":
onnx_inputs[name] = onnx_inputs[name].astype(np.float16)
outputs = session.run(None, onnx_inputs)
del session

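The cast added to `fix_dynamic_axes` above follows this pattern; a self-contained sketch on plain numpy inputs, with made-up tensor names:

```python
import numpy as np

def cast_fp32_inputs_to_fp16(onnx_inputs: dict) -> dict:
    """Cast only float32 arrays to float16; integer inputs (ids, masks) are left untouched."""
    return {
        name: value.astype(np.float16) if value.dtype == np.float32 else value
        for name, value in onnx_inputs.items()
    }

inputs = {
    "pixel_values": np.random.rand(1, 3, 224, 224).astype(np.float32),  # made-up names
    "attention_mask": np.ones((1, 8), dtype=np.int64),
}
casted = cast_fp32_inputs_to_fp16(inputs)
assert casted["pixel_values"].dtype == np.float16
assert casted["attention_mask"].dtype == np.int64
```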
50 changes: 42 additions & 8 deletions optimum/exporters/onnx/convert.py
@@ -34,7 +34,7 @@
)
from ..error_utils import AtolError, OutputMatchError, ShapeError
from .base import OnnxConfig
from .utils import recursive_to_device
from .utils import recursive_to_device, recursive_to_dtype


if is_torch_available():
Expand Down Expand Up @@ -91,6 +91,7 @@ def validate_models_outputs(
onnx_files_subpaths: Optional[List[str]] = None,
input_shapes: Optional[Dict] = None,
device: str = "cpu",
dtype: Optional["torch.dtype"] = None,
):
"""
Validates the export of several models, by checking that the outputs from both the reference and the exported model match.
@@ -112,6 +113,8 @@
If specified, allows the use of specific shapes when validating the ONNX model.
device (`str`, defaults to `"cpu"`):
The device on which the ONNX models will be validated. Either `cpu` or `cuda`. Validation on a CUDA device is supported only for PyTorch.
dtype (`Optional[torch.dtype]`, defaults to `None`):
Data type of the inputs to perform validation on. Validation on float16 is supported only for PyTorch.

Raises:
ValueError: If the outputs shapes or values do not match between the reference and the exported model.
@@ -145,6 +148,7 @@
atol=atol,
input_shapes=input_shapes,
device=device,
dtype=dtype,
)
except Exception as e:
exceptions.append(e)
@@ -163,6 +167,7 @@
atol: Optional[float] = None,
input_shapes: Optional[Dict] = None,
device: str = "cpu",
dtype: Optional["torch.dtype"] = None,
):
"""
Validates the export by checking that the outputs from both the reference and the exported model match.
@@ -252,6 +257,9 @@

for key, value in reference_model_inputs.items():
reference_model_inputs[key] = recursive_to_device(value=value, device=device)
reference_model_inputs[key] = recursive_to_dtype(
value=reference_model_inputs[key], dtype=dtype, start_dtype=torch.float32
)

ref_outputs = reference_model(**reference_model_inputs)
ref_outputs_dict = {}
@@ -351,6 +359,7 @@ def export_pytorch(
opset: int,
output: Path,
device: str = "cpu",
dtype: Optional["torch.dtype"] = None,
input_shapes: Optional[Dict] = None,
) -> Tuple[List[str], List[str]]:
"""
@@ -368,7 +377,9 @@
device (`str`, defaults to `"cpu"`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
input_shapes (`optional[Dict]`, defaults to `None`):
dtype (`Optional[torch.dtype]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `torch.float16` is supported.
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows the use of specific shapes for the example input provided to the ONNX exporter.

Returns:
@@ -399,11 +410,18 @@
dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)

device = torch.device(device)

def remap(value):
if isinstance(value, torch.Tensor):
value = value.to(device)
if isinstance(value, torch.Tensor) and value.dtype == torch.float32:
value = value.to(dtype=dtype)

return value

if device.type == "cuda" and torch.cuda.is_available():
model = model.to(device)
dummy_inputs = tree_map(
lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
)
model.to(device)
dummy_inputs = tree_map(remap, dummy_inputs)
check_dummy_inputs_are_allowed(model, dummy_inputs)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
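The `remap` closure above is applied through `tree_map`, so nested containers of dummy inputs are handled too; a standalone sketch of the same idea, with invented inputs and the device chosen dynamically so it also runs on CPU (the `tree_map` import shown here is one possible source, the PR's own import is outside this hunk):

```python
import torch
from torch.utils._pytree import tree_map

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16

def remap(value):
    # Move tensors to the target device; cast only float32 tensors to the target dtype.
    if isinstance(value, torch.Tensor):
        value = value.to(device)
        if value.dtype == torch.float32:
            value = value.to(dtype=dtype)
    return value

# Invented dummy inputs: integer ids stay int64, nested float tensors become float16.
dummy_inputs = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "past_key_values": (torch.rand(1, 2, 8, 4), torch.rand(1, 2, 8, 4)),
}
dummy_inputs = tree_map(remap, dummy_inputs)
assert dummy_inputs["past_key_values"][0].dtype == torch.float16
assert dummy_inputs["input_ids"].dtype == torch.long
```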
@@ -542,6 +560,7 @@ def export_models(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Exports a PyTorch or TensorFlow encoder-decoder model to an ONNX Intermediate Representation.
@@ -565,6 +584,8 @@
If specified, allows the use of specific shapes for the example input provided to the ONNX exporter.
disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`):
Whether to disable the default dynamic axes fixing.
dtype (`Optional[str]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
Returns:
`Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named
outputs from the ONNX configuration.
@@ -592,6 +613,7 @@
device=device,
input_shapes=input_shapes,
disable_dynamic_axes_fix=disable_dynamic_axes_fix,
dtype=dtype,
)
)

@@ -607,6 +629,7 @@ def export(
device: str = "cpu",
input_shapes: Optional[Dict] = None,
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch or TensorFlow model to an ONNX Intermediate Representation.
@@ -627,6 +650,8 @@
If specified, allows the use of specific shapes for the example input provided to the ONNX exporter.
disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`):
Whether to disable the default dynamic axes fixing.
dtype (`Optional[str]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
@@ -661,7 +686,16 @@ def export(
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION},"
f" got: {torch.__version__}"
)
export_output = export_pytorch(model, config, opset, output, device=device, input_shapes=input_shapes)

torch_dtype = None
if dtype == "fp16":
torch_dtype = torch.float16
elif dtype is not None:
raise ValueError("Unsupported dtype, supported dtypes are: `torch.float16`.")

export_output = export_pytorch(
model, config, opset, output, device=device, input_shapes=input_shapes, dtype=torch_dtype
)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
if device == "cuda":
@@ -676,5 +710,5 @@
)

if not disable_dynamic_axes_fix:
config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes)
config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
return export_output
23 changes: 22 additions & 1 deletion optimum/exporters/onnx/utils.py
@@ -15,7 +15,7 @@
"""Utility functions."""

import copy
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import packaging
import torch
@@ -200,3 +200,24 @@ def recursive_to_device(value: Union[Tuple, List, "torch.Tensor"], device: str):
value = value.to(device)

return value


def recursive_to_dtype(
value: Union[Tuple, List, "torch.Tensor"], dtype: Optional[torch.dtype], start_dtype: Optional[torch.dtype] = None
):
if dtype is None:
return value

if isinstance(value, tuple):
value = list(value)
for i, val in enumerate(value):
value[i] = recursive_to_dtype(val, dtype, start_dtype=start_dtype)
value = tuple(value)
elif isinstance(value, list):
for i, val in enumerate(value):
value[i] = recursive_to_dtype(val, dtype, start_dtype=start_dtype)
elif isinstance(value, torch.Tensor):
if start_dtype is None or value.dtype == start_dtype:
value = value.to(dtype=dtype)

return value
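A small usage example of the helper above, with an invented nested structure; the import path assumes a checkout of this branch is installed, since the module is added in this PR:

```python
import torch
from optimum.exporters.onnx.utils import recursive_to_dtype

# Nested float32 tensors are cast to float16.
value = ([torch.rand(2, 4), torch.rand(2, 4)], torch.rand(2, 4))
casted = recursive_to_dtype(value, dtype=torch.float16, start_dtype=torch.float32)
assert all(t.dtype == torch.float16 for t in casted[0])
assert casted[1].dtype == torch.float16

# With start_dtype=torch.float32, a tensor of another dtype is left untouched.
mask = torch.ones(2, 4, dtype=torch.long)
assert recursive_to_dtype(mask, dtype=torch.float16, start_dtype=torch.float32).dtype == torch.long
```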
6 changes: 6 additions & 0 deletions optimum/exporters/tasks.py
@@ -26,6 +26,7 @@


if TYPE_CHECKING:
import torch
from transformers import PreTrainedModel, TFPreTrainedModel

from .base import ExportConfig
@@ -1017,6 +1018,7 @@ def get_model_from_task(
revision: Optional[str] = None,
framework: Optional[str] = None,
cache_dir: Optional[str] = None,
torch_dtype: Optional["torch.dtype"] = None,
**model_kwargs,
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
@@ -1038,6 +1040,8 @@
none be provided.
cache_dir (`Optional[str]`, *optional*):
Path to a directory in which the downloaded pretrained model weights have been cached if the standard cache should not be used.
torch_dtype (`Optional[torch.dtype]`, defaults to `None`):
Data type to load the model on. PyTorch-only argument.
model_kwargs (`Dict[str, Any]`, *optional*):
Keyword arguments to pass to the model `.from_pretrained()` method.

@@ -1051,6 +1055,8 @@
model_class = TasksManager.get_model_class_for_task(task, framework)
kwargs = {"subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, **model_kwargs}
try:
if framework == "pt":
kwargs["torch_dtype"] = torch_dtype
model = model_class.from_pretrained(model_name_or_path, **kwargs)
except OSError:
if framework == "pt":
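Since the `torch_dtype` keyword is forwarded to `from_pretrained`, the effect is the same as loading the checkpoint in half precision directly; a minimal sketch with a placeholder model name:

```python
import torch
from transformers import AutoModel

# Placeholder checkpoint name; any PyTorch checkpoint behaves the same way.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16)
assert next(model.parameters()).dtype == torch.float16
```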
36 changes: 34 additions & 2 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -97,15 +97,16 @@ def _onnx_export(
no_post_process: bool = False,
optimization_level: Optional[str] = None,
device: str = None,
fp16: bool = False,
):
with TemporaryDirectory() as tmpdir:
monolith = " --monolith " if monolith is True else " "
no_post_process = " --no-post-process " if no_post_process is True else " "
optimization_level = f" --optimize {optimization_level} " if optimization_level is not None else " "
task = f" --task {task} " if task is not None else " "
device = " --device cuda " if device == "cuda" else " "

command = f"python3 -m optimum.exporters.onnx --model {model_name}{monolith}{optimization_level}{device}{no_post_process}{task}{tmpdir}"
fp16 = " --fp16 --device cuda " if fp16 is True else " "
command = f"python3 -m optimum.exporters.onnx --model {model_name}{monolith}{fp16}{optimization_level}{device}{no_post_process}{task}{tmpdir}"
print("\nRUNNING:", command)
out = subprocess.run(
command,
@@ -250,3 +251,34 @@ def test_stable_diffusion(self):
shell=True,
check=True,
)

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
@require_vision
@require_torch_gpu
@slow
@pytest.mark.run_slow
def test_export_on_fp16(
self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool
):
# TODO: refer to https://github.com/pytorch/pytorch/issues/95377
if model_type == "yolos":
self.skipTest("yolos export on fp16 not supported due to a pytorch bug")

# TODO: refer to https://huggingface.slack.com/archives/C014N4749J9/p1677245766278129
if model_type == "deberta":
self.skipTest("deberta export on fp16 not supported due to a transformers bug")

# TODO: test once https://github.com/huggingface/transformers/pull/21789 is released
if (model_type == "vit" and task == "masked-im") or model_type == "vision-encoder-decoder":
self.skipTest(
"vit + masked-im, and vision-encoder-decoder export on fp16 not supported due to a transformers bug"
)

# TODO: test once https://github.com/huggingface/transformers/pull/21787 is released
if model_type == "perceiver" and task == "image-classification":
self.skipTest("perceiver + image-classification export on fp16 not supported due to a transformers bug")

if model_type == "ibert":
self.skipTest("ibert can not be supported in fp16")

self._onnx_export(model_name, task, monolith, no_post_process, fp16=True)
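Beyond the output-matching validation exercised by these tests, a quick structural check of an fp16 export can be done with the `onnx` package; a sketch with a placeholder path, assuming the model was exported with `--fp16 --device cuda`:

```python
import onnx

# Placeholder path to the exported model file.
model = onnx.load("bert_onnx_fp16/model.onnx")

# Floating-point graph inputs of an fp16 export are declared as FLOAT16,
# while integer inputs such as input_ids remain INT64.
for graph_input in model.graph.input:
    elem_type = graph_input.type.tensor_type.elem_type
    print(graph_input.name, onnx.TensorProto.DataType.Name(elem_type))
```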