Skip to content

Commit

Permalink
Expose the batch size of INT8 calibration as a parameter, since different
Browse files Browse the repository at this point in the history
batch sizes may result in different accuracy loss.
  • Loading branch information
Chujingjun committed Sep 2, 2020
1 parent 63895f0 commit 3ddd70d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def torch2trt(module,
int8_mode=False,
int8_calib_dataset=None,
int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
int8_calib_batch_size=1,
use_onnx=False):

inputs_in = inputs
Expand Down Expand Up @@ -454,7 +455,7 @@ def torch2trt(module,

# @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption
builder.int8_calibrator = DatasetCalibrator(
inputs, int8_calib_dataset, batch_size=1, algorithm=int8_calib_algorithm
inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
)

engine = builder.build_cuda_engine(network)
Expand Down

0 comments on commit 3ddd70d

Please sign in to comment.