diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 732bed5a..17bceca5 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -537,7 +537,24 @@ def torch2trt(module, if use_onnx: f = io.BytesIO() - torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names) + onnx_kwargs = kwargs.get('onnx_kwargs', {}) + opset_version = onnx_kwargs.get('opset_version', 9) + do_constant_folding = onnx_kwargs.get('do_constant_folding', True) + export_params = onnx_kwargs.get('export_params', True) + verbose_onnx = onnx_kwargs.get('verbose', log_level == trt.Logger.VERBOSE) + dynamic_axes = onnx_kwargs.get('dynamic_axes', None) + torch.onnx.export( + module, + inputs, + f, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + export_params=export_params, + verbose=verbose_onnx, + dynamic_axes=dynamic_axes, + ) f.seek(0) onnx_bytes = f.read() network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))