fix data device type bug (microsoft#3856)
linbinskn authored Jun 22, 2021
1 parent 009722a commit 27e123d
Showing 2 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/en_US/Compression/QuantizationSpeedup.rst
@@ -39,6 +39,7 @@ After getting mixed precision engine, users can do inference with input data.
Note


* We recommend using "cpu" (host) as the data device for both inference data and calibration data, since the data needs to be on the host initially and will be transferred to the device before inference. If the data is not on "cpu" (host), this tool will transfer it to "cpu" first, which may add unnecessary overhead (a minimal sketch follows this list).
* Users can also do post-training quantization by leveraging TensorRT directly (a calibration dataset needs to be provided).
* Not all op types are supported right now. At present, NNI supports Conv, Linear, Relu and MaxPool. More op types will be supported in the following release.
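
A minimal sketch of the first note above, using only standalone PyTorch; the engine object and its ``inference`` call are placeholders for whatever speedup engine the tool returns, not a confirmed NNI API:

import torch

# Keeping the input on the host avoids an extra device-to-host copy,
# since the tool converts the tensor to a numpy array before inference.
test_data = torch.randn(32, 1, 28, 28, dtype=torch.float32)  # already on "cpu"
# output = engine.inference(test_data)          # hypothetical engine call

# If the data lives on the GPU, the tool now moves it back to the host
# first; this works, but it introduces an avoidable cuda -> cpu transfer.
if torch.cuda.is_available():
    gpu_data = test_data.to("cuda")
    # output = engine.inference(gpu_data)       # extra copy happens inside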

@@ -290,6 +290,9 @@ def _tensorrt_build_withcalib(self, onnx_path):
                calib_data_set.append(data)
            calib_data = np.concatenate(calib_data_set)
        elif type(self.calib_data_loader) == torch.Tensor:
            # TensorRT needs numpy arrays as calibration data; only CPU tensors can be converted to numpy directly
            if self.calib_data_loader.device != torch.device("cpu"):
                self.calib_data_loader = self.calib_data_loader.to("cpu")
            calib_data = self.calib_data_loader.numpy()
        else:
            raise ValueError("Not support calibration datatype")
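
The reason for the added device check, shown in isolation with standalone PyTorch and NumPy (no NNI code assumed): ``Tensor.numpy()`` only works on CPU tensors, so a CUDA tensor must be moved to the host first.

import numpy as np
import torch

calib = torch.randn(8, 3, 32, 32)            # CPU tensor: converts directly
calib_np = calib.numpy()
assert isinstance(calib_np, np.ndarray)

if torch.cuda.is_available():
    calib_gpu = torch.randn(8, 3, 32, 32, device="cuda")
    # calib_gpu.numpy() would raise a TypeError for a CUDA tensor
    calib_np = calib_gpu.to("cpu").numpy()    # move to host first, as in the fix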
@@ -326,6 +329,8 @@ def inference(self, test_data):
            Model input tensor
        """
        # convert pytorch tensor to numpy ndarray
        if test_data.device != torch.device("cpu"):
            test_data = test_data.to("cpu")
        test_data = test_data.numpy()
        # Numpy dtype should be float32
        assert test_data.dtype == np.float32
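
The inference path has the same requirement plus a float32 constraint. Below is a small sketch of the same preprocessing; the helper function is hypothetical and only mirrors the logic added in this commit, it is not part of NNI:

import numpy as np
import torch

def to_engine_input(test_data: torch.Tensor) -> np.ndarray:
    """Hypothetical helper mirroring the fixed inference() preprocessing."""
    if test_data.device != torch.device("cpu"):
        test_data = test_data.to("cpu")      # host copy before numpy conversion
    arr = test_data.numpy()
    assert arr.dtype == np.float32           # the TensorRT path expects float32
    return arr

batch = torch.randn(16, 1, 28, 28, dtype=torch.float32)
print(to_engine_input(batch).shape)                      # CPU input works
if torch.cuda.is_available():
    print(to_engine_input(batch.to("cuda")).shape)       # CUDA input works after the fix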
