From 66e14b26b2c14cb2e88f2b2983eadc66023f1b10 Mon Sep 17 00:00:00 2001 From: Lazaros Toumanidis Date: Wed, 6 Apr 2022 13:57:35 +0300 Subject: [PATCH 1/6] add onnx kwargs --- torch2trt/torch2trt.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 732bed5a..5ade7284 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -537,7 +537,20 @@ def torch2trt(module, if use_onnx: f = io.BytesIO() - torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names) + onnx_kwargs = kwargs.get('onnx_kwargs', {}) + opset_version = onnx_kwargs.get('opset_version', 9) + do_constant_folding = onnx_kwargs.get('do_constant_folding', True) + export_params = onnx_kwargs.get('export_params', False) + torch.onnx.export( + module, + inputs, + f, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + export_params=export_params, + ) f.seek(0) onnx_bytes = f.read() network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) From ffe8e6477768b339ba5f4289cab8244014023568 Mon Sep 17 00:00:00 2001 From: Lazaros Toumanidis Date: Wed, 6 Apr 2022 21:45:41 +0300 Subject: [PATCH 2/6] Update calibration.py --- torch2trt/calibration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2trt/calibration.py b/torch2trt/calibration.py index 7506ea4e..6803399e 100644 --- a/torch2trt/calibration.py +++ b/torch2trt/calibration.py @@ -2,7 +2,7 @@ import tensorrt as trt -if trt.__version__ >= '5.1': +if str(trt.__version__) >= '5.1': DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 else: DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION @@ -66,4 +66,4 @@ def read_calibration_cache(self, *args, **kwargs): return None def write_calibration_cache(self, cache, *args, **kwargs): - pass \ No newline at end of file + pass From b295221aeaf44449ae1e1b6e62ee74d6c78ac7be Mon Sep 17 00:00:00 2001 From: Lazaros Toumanidis Date: Wed, 6 Apr 2022 21:51:30 +0300 Subject: [PATCH 3/6] fix trt version: TypeError: '<' not supported between instances of '_MockObject' and 'str' --- torch2trt/torch2trt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 5ade7284..fec80315 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -16,11 +16,11 @@ def trt_version(): - return trt.__version__ + return str(trt.__version__) def torch_version(): - return torch.__version__ + return str(torch.__version__) def torch_dtype_to_trt(dtype): From d0590d6034c5b892942f6b39059e0f021fe1cae1 Mon Sep 17 00:00:00 2001 From: Lazaros Toumanidis Date: Fri, 8 Apr 2022 15:15:26 +0300 Subject: [PATCH 4/6] trt version --- torch2trt/calibration.py | 42 ++++++++++++++++++++-------------------- torch2trt/torch2trt.py | 4 ++-- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/torch2trt/calibration.py b/torch2trt/calibration.py index 6803399e..10a930b1 100644 --- a/torch2trt/calibration.py +++ b/torch2trt/calibration.py @@ -2,68 +2,68 @@ import tensorrt as trt -if str(trt.__version__) >= '5.1': +if trt.__version__ >= '5.1': DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 else: DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION - + class TensorBatchDataset(): - + def __init__(self, tensors): self.tensors = tensors - + def __len__(self): return len(self.tensors[0]) - + def __getitem__(self, idx): return [t[idx] for t in self.tensors] - - + + class DatasetCalibrator(trt.IInt8Calibrator): - + def __init__(self, inputs, dataset, batch_size=1, algorithm=DEFAULT_CALIBRATION_ALGORITHM): super(DatasetCalibrator, self).__init__() - + self.dataset = dataset self.batch_size = batch_size self.algorithm = algorithm - + # create buffers that will hold data batches self.buffers = [] for tensor in inputs: size = (batch_size,) + tuple(tensor.shape[1:]) buf = torch.zeros(size=size, dtype=tensor.dtype, device=tensor.device).contiguous() self.buffers.append(buf) - + self.count = 0 - + def get_batch(self, *args, **kwargs): if self.count < len(self.dataset): - + for i in range(self.batch_size): - + idx = self.count % len(self.dataset) # roll around if not multiple of dataset inputs = self.dataset[idx] - + # copy data for (input_idx, dataset_idx) into buffer for buffer, tensor in zip(self.buffers, inputs): buffer[i].copy_(tensor) - + self.count += 1 - + return [int(buf.data_ptr()) for buf in self.buffers] else: return [] - + def get_algorithm(self): return self.algorithm - + def get_batch_size(self): return self.batch_size - + def read_calibration_cache(self, *args, **kwargs): return None - + def write_calibration_cache(self, cache, *args, **kwargs): pass diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index fec80315..5ade7284 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -16,11 +16,11 @@ def trt_version(): - return str(trt.__version__) + return trt.__version__ def torch_version(): - return str(torch.__version__) + return torch.__version__ def torch_dtype_to_trt(dtype): From 8bf8cf4340490bdc373e01e964ec287d57d737d3 Mon Sep 17 00:00:00 2001 From: Lazaros Toumanidis Date: Fri, 8 Apr 2022 15:18:45 +0300 Subject: [PATCH 5/6] revert calibration.py --- torch2trt/calibration.py | 42 ++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/torch2trt/calibration.py b/torch2trt/calibration.py index 10a930b1..7506ea4e 100644 --- a/torch2trt/calibration.py +++ b/torch2trt/calibration.py @@ -6,64 +6,64 @@ DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 else: DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION - + class TensorBatchDataset(): - + def __init__(self, tensors): self.tensors = tensors - + def __len__(self): return len(self.tensors[0]) - + def __getitem__(self, idx): return [t[idx] for t in self.tensors] - - + + class DatasetCalibrator(trt.IInt8Calibrator): - + def __init__(self, inputs, dataset, batch_size=1, algorithm=DEFAULT_CALIBRATION_ALGORITHM): super(DatasetCalibrator, self).__init__() - + self.dataset = dataset self.batch_size = batch_size self.algorithm = algorithm - + # create buffers that will hold data batches self.buffers = [] for tensor in inputs: size = (batch_size,) + tuple(tensor.shape[1:]) buf = torch.zeros(size=size, dtype=tensor.dtype, device=tensor.device).contiguous() self.buffers.append(buf) - + self.count = 0 - + def get_batch(self, *args, **kwargs): if self.count < len(self.dataset): - + for i in range(self.batch_size): - + idx = self.count % len(self.dataset) # roll around if not multiple of dataset inputs = self.dataset[idx] - + # copy data for (input_idx, dataset_idx) into buffer for buffer, tensor in zip(self.buffers, inputs): buffer[i].copy_(tensor) - + self.count += 1 - + return [int(buf.data_ptr()) for buf in self.buffers] else: return [] - + def get_algorithm(self): return self.algorithm - + def get_batch_size(self): return self.batch_size - + def read_calibration_cache(self, *args, **kwargs): return None - + def write_calibration_cache(self, cache, *args, **kwargs): - pass + pass \ No newline at end of file From 463f2af659ce6223b0e48af45bf46034a7052b8a Mon Sep 17 00:00:00 2001 From: Lazaros Toumanidis Date: Fri, 8 Apr 2022 16:44:20 +0300 Subject: [PATCH 6/6] more onnx args --- torch2trt/torch2trt.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 5ade7284..17bceca5 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -540,7 +540,9 @@ def torch2trt(module, onnx_kwargs = kwargs.get('onnx_kwargs', {}) opset_version = onnx_kwargs.get('opset_version', 9) do_constant_folding = onnx_kwargs.get('do_constant_folding', True) - export_params = onnx_kwargs.get('export_params', False) + export_params = onnx_kwargs.get('export_params', True) + verbose_onnx = onnx_kwargs.get('verbose', log_level == trt.Logger.VERBOSE) + dynamic_axes = onnx_kwargs.get('dynamic_axes', None) torch.onnx.export( module, inputs, @@ -550,6 +552,8 @@ def torch2trt(module, opset_version=opset_version, do_constant_folding=do_constant_folding, export_params=export_params, + verbose=verbose_onnx, + dynamic_axes=dynamic_axes, ) f.seek(0) onnx_bytes = f.read()