Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix potential bug in quantize device #3212

Merged
merged 1 commit into from
Dec 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,10 @@ def update_quantization_param(bits, rmin, rmax):
# extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly
# representable value.
if rmin.is_cuda:
rmin = torch.min(rmin, torch.Tensor([0]).cuda())
rmax = torch.max(rmax, torch.Tensor([0]).cuda())
qmin = torch.Tensor([0]).cuda()
qmax = torch.Tensor([(1 << bits) - 1]).cuda()
else:
rmin = torch.min(rmin, torch.Tensor([0]))
rmax = torch.max(rmax, torch.Tensor([0]))
qmin = torch.Tensor([0])
qmax = torch.Tensor([(1 << bits) - 1])
rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device))
rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device))
qmin = torch.Tensor([0]).to(rmin.device)
qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device)

# First determine the scale.
scale = (rmax - rmin) / (qmax - qmin)
Expand All @@ -103,7 +97,6 @@ def update_quantization_param(bits, rmin, rmax):
initial_zero_point = qmin - rmin / scale

# Now we need to nudge the zero point to be an integer
nudged_zero_point = 0
if initial_zero_point < qmin:
nudged_zero_point = qmin
elif initial_zero_point > qmax:
Expand Down Expand Up @@ -199,10 +192,8 @@ def _quantize(self, bits, op, real_val):
-------
Tensor
"""
if real_val.is_cuda:
op.zero_point = op.zero_point.cuda()
op.scale = op.scale.cuda()

op.zero_point = op.zero_point.to(real_val.device)
op.scale = op.scale.to(real_val.device)
transformed_val = op.zero_point + real_val / op.scale
qmin = 0
qmax = (1 << bits) - 1
Expand Down