diff --git a/torch2trt/calibration.py b/torch2trt/calibration.py
index 7506ea4e..b09b7193 100644
--- a/torch2trt/calibration.py
+++ b/torch2trt/calibration.py
@@ -51,7 +51,7 @@ def get_batch(self, *args, **kwargs):
                     buffer[i].copy_(tensor)
 
                 self.count += 1
-                
+
             return [int(buf.data_ptr()) for buf in self.buffers]
         else:
             return []
diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index e2aba5d4..58d84672 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -501,10 +501,12 @@ def torch2trt(module,
               strict_type_constraints=False,
               keep_network=True,
               int8_mode=False,
+              int8_calibrator=None,
               int8_calib_dataset=None,
               int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
               int8_calib_batch_size=1,
               use_onnx=False,
+              onnx_file_path=None,
               **kwargs):
 
     # capture arguments to provide to context
@@ -518,7 +520,7 @@ def torch2trt(module,
     logger = trt.Logger(log_level)
     builder = trt.Builder(logger)
     config = builder.create_builder_config()
-    
+
     if isinstance(inputs, list):
         inputs = tuple(inputs)
     if not isinstance(inputs, tuple):
@@ -535,14 +537,21 @@
         output_names = default_output_names(len(outputs))
 
     if use_onnx:
-
-        f = io.BytesIO()
-        torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names)
-        f.seek(0)
-        onnx_bytes = f.read()
         network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
         parser = trt.OnnxParser(network, logger)
-        parser.parse(onnx_bytes)
+        if onnx_file_path is not None:
+            print('\tBeginning ONNX file parsing, path =', onnx_file_path)
+            with open(onnx_file_path, 'rb') as onnx_model_file:
+                onnx_model = onnx_model_file.read()
+            if not parser.parse(onnx_model):
+                raise RuntimeError('ONNX model parsing from {} failed. Error: {}'.format(onnx_file_path, parser.get_error(0).desc()))
+            print('\tEND ONNX file parsing.')
+        else:
+            f = io.BytesIO()
+            torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names, opset_version=13)
+            f.seek(0)
+            onnx_bytes = f.read()
+            parser.parse(onnx_bytes)
 
     else:
         network = builder.create_network()
@@ -567,15 +576,19 @@
         if int8_calib_dataset is None:
             int8_calib_dataset = TensorBatchDataset(inputs_in)
 
+        config.set_flag(trt.BuilderFlag.FP16)
         config.set_flag(trt.BuilderFlag.INT8)
-        
+
         #Making sure not to run calibration with QAT mode on
         if not 'qat_mode' in kwargs:
             # @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption
-            calibrator = DatasetCalibrator(
-                inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
-            )
-            config.int8_calibrator = calibrator
+            if int8_calibrator is None:
+                calibrator = DatasetCalibrator(
+                    inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
+                )
+                config.int8_calibrator = calibrator
+            else:
+                config.int8_calibrator = int8_calibrator
 
     engine = builder.build_engine(network, config)
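
Usage sketch (not part of the patch): a minimal example of how the two new
keyword arguments could be exercised. The ResNet-18 model, the 'model.onnx'
path, and the single-sample calibration dataset are placeholders, and the
torch2trt.calibration import path is an assumption based on the file paths
in the diff above.

    import torch
    import torchvision
    from torch2trt import torch2trt
    from torch2trt.calibration import DatasetCalibrator, TensorBatchDataset

    model = torchvision.models.resnet18(pretrained=True).eval().cuda()
    x = torch.randn(1, 3, 224, 224).cuda()

    # Parse a previously exported ONNX file instead of exporting in memory.
    model_trt = torch2trt(model, [x], use_onnx=True, onnx_file_path='model.onnx')

    # Hand torch2trt a ready-made calibrator instead of letting it build a
    # DatasetCalibrator from int8_calib_dataset; any trt.IInt8Calibrator
    # implementation should work, the built-in DatasetCalibrator is shown here.
    calibrator = DatasetCalibrator((x,), TensorBatchDataset((x,)))
    model_trt_int8 = torch2trt(model, [x], int8_mode=True, int8_calibrator=calibrator)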