Commit 3828d82

add the option to pass an ONNX file path if there is one (instead of generating on the go)

add the option to override the default int8 calibrator with a custom calibrator instance
Ibrahim Abedrabbo committed May 25, 2022
1 parent 441c46a commit 3828d82
Showing 2 changed files with 29 additions and 13 deletions.
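For reference, a minimal usage sketch of the two new keyword arguments (the module, input tensor, and file name below are hypothetical; my_calibrator is defined in the calibrator sketch after the diff):

    import torch
    from torch2trt import torch2trt

    model = MyModel().eval().cuda()            # hypothetical module
    x = torch.randn(1, 3, 224, 224).cuda()     # example input

    # Reuse a pre-exported ONNX file instead of exporting on the go, and
    # override the default DatasetCalibrator with a custom instance.
    model_trt = torch2trt(
        model, [x],
        use_onnx=True,
        onnx_file_path='model.onnx',           # pre-exported ONNX file (assumed to exist)
        int8_mode=True,
        int8_calibrator=my_calibrator,         # any TensorRT INT8 calibrator instance
    )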
2 changes: 1 addition & 1 deletion torch2trt/calibration.py
@@ -51,7 +51,7 @@ def get_batch(self, *args, **kwargs):
                     buffer[i].copy_(tensor)
 
             self.count += 1
 
             return [int(buf.data_ptr()) for buf in self.buffers]
         else:
             return []
40 changes: 28 additions & 12 deletions torch2trt/torch2trt.py
@@ -501,10 +501,12 @@ def torch2trt(module,
               strict_type_constraints=False,
               keep_network=True,
               int8_mode=False,
+              int8_calibrator=None,
               int8_calib_dataset=None,
               int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
               int8_calib_batch_size=1,
               use_onnx=False,
+              onnx_file_path=None,
               **kwargs):
 
     # capture arguments to provide to context
@@ -518,7 +520,7 @@ def torch2trt(module,
     logger = trt.Logger(log_level)
     builder = trt.Builder(logger)
     config = builder.create_builder_config()
 
     if isinstance(inputs, list):
         inputs = tuple(inputs)
     if not isinstance(inputs, tuple):
@@ -535,14 +537,24 @@ def torch2trt(module,
     output_names = default_output_names(len(outputs))
 
     if use_onnx:
-        f = io.BytesIO()
-        torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names)
-        f.seek(0)
-        onnx_bytes = f.read()
         network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
         parser = trt.OnnxParser(network, logger)
-        parser.parse(onnx_bytes)
+        if onnx_file_path is not None:
+            # parse a pre-exported ONNX file instead of exporting the module on the go
+            print('\tBeginning ONNX file parsing, path =', onnx_file_path)
+            with open(onnx_file_path, 'rb') as onnx_model_file:
+                onnx_model = onnx_model_file.read()
+            # parse() returns False on failure; surface the first parser error
+            if not parser.parse(onnx_model):
+                raise RuntimeError('ONNX model parsing from {} failed. Error: {}'.format(
+                    onnx_file_path, parser.get_error(0).desc()))
+            print('\tEND ONNX file parsing.')
+        else:
+            f = io.BytesIO()
+            torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names, opset_version=13)
+            f.seek(0)
+            onnx_bytes = f.read()
+            parser.parse(onnx_bytes)
 
     else:
         network = builder.create_network()
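The onnx_file_path fast path assumes the ONNX file already exists on disk. A minimal sketch of producing one ahead of time with torch.onnx.export (the module, file name, and tensor names are illustrative; the tensor names should match the input_names/output_names torch2trt would otherwise generate):

    import torch

    model = MyModel().eval().cuda()            # hypothetical module
    x = torch.randn(1, 3, 224, 224).cuda()
    torch.onnx.export(
        model, x, 'model.onnx',
        input_names=['input_0'], output_names=['output_0'],
        opset_version=13,
    )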
@@ -567,15 +579,19 @@ def torch2trt(module,
         if int8_calib_dataset is None:
             int8_calib_dataset = TensorBatchDataset(inputs_in)
 
+        config.set_flag(trt.BuilderFlag.FP16)
         config.set_flag(trt.BuilderFlag.INT8)
 
         # Making sure not to run calibration with QAT mode on
         if not 'qat_mode' in kwargs:
-            # @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption
-            calibrator = DatasetCalibrator(
-                inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
-            )
-            config.int8_calibrator = calibrator
+            if int8_calibrator is None:
+                calibrator = DatasetCalibrator(
+                    inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
+                )
+                config.int8_calibrator = calibrator
+            else:
+                # use the caller-supplied calibrator instance as-is
+                config.int8_calibrator = int8_calibrator
 
     engine = builder.build_engine(network, config)

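The new int8_calibrator argument accepts any object implementing TensorRT's INT8 calibration interface, so the DatasetCalibrator default can be bypassed entirely. A minimal sketch of a custom calibrator built on trt.IInt8EntropyCalibrator2 (the class name, batch source, and cache file are hypothetical, not part of this commit):

    import tensorrt as trt
    import torch

    class MyCalibrator(trt.IInt8EntropyCalibrator2):
        """Feeds CUDA batches to TensorRT during INT8 calibration."""

        def __init__(self, batches, cache_file='calib.cache'):
            trt.IInt8EntropyCalibrator2.__init__(self)
            self.batches = batches        # list of CUDA tensors, one per batch
            self.cache_file = cache_file
            self.index = 0

        def get_batch_size(self):
            return self.batches[0].shape[0]

        def get_batch(self, names):
            if self.index >= len(self.batches):
                return None               # None tells TensorRT calibration is done
            batch = self.batches[self.index]
            self.index += 1
            return [int(batch.data_ptr())]  # one device pointer per input binding

        def read_calibration_cache(self):
            try:
                with open(self.cache_file, 'rb') as f:
                    return f.read()
            except FileNotFoundError:
                return None

        def write_calibration_cache(self, cache):
            with open(self.cache_file, 'wb') as f:
                f.write(cache)

    my_calibrator = MyCalibrator([torch.randn(8, 3, 224, 224).cuda() for _ in range(4)])

Passed as torch2trt(..., int8_mode=True, int8_calibrator=my_calibrator), this instance is assigned to config.int8_calibrator in place of the default.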
