fix wrong quantization target in weight quantization #4038
Conversation
```diff
@@ -221,15 +220,13 @@ def quantize_input(self, *inputs, wrapper, **kwargs):
         self.record(wrapper, 'input', inputs)
         return inputs

-    def quantize_weight(self, wrapper, **kwargs):
+    def quantize_weight(self, weight, wrapper, **kwargs):
```
Could you explain more about this change? `weight` can be obtained from `wrapper`, so why pass it again?
It's about code readability. Since we already pass the weight to be quantized (`new_weight`) to `quant_grad` here, it is better to use it directly instead of obtaining it from the wrapper. The developer can then easily see that `quantize_weight` is a big op that simulates quantization and takes the origin/BN-folded weight as input. I think using `wrapper.weight` would make it harder to understand the structure of the training graph.
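For concreteness, here is a minimal, self-contained sketch of the data flow described above. `DummyQuantizer` and `DummyWrapper` are illustrative stand-ins, not NNI's actual classes: the wrapper prepares the weight to be quantized (the `new_weight` that would go to `quant_grad`) and hands it to `quantize_weight` directly, so the quantizer never has to reach back into the wrapper for it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyQuantizer:
    def quantize_weight(self, weight, wrapper, **kwargs):
        # `weight` is exactly the tensor prepared by the wrapper's forward
        # (origin or BN-folded weight); no need to read it from the wrapper.
        scale = weight.abs().max() / 127 + 1e-12
        return torch.clamp(torch.round(weight / scale), -128, 127) * scale


class DummyWrapper(nn.Module):
    def __init__(self, module, quantizer):
        super().__init__()
        self.module = module
        self.quantizer = quantizer

    def forward(self, x):
        # new_weight is the weight actually used in this forward pass;
        # it is passed to quantize_weight directly.
        new_weight = self.module.weight
        quantized_weight = self.quantizer.quantize_weight(new_weight, self)
        return F.linear(x, quantized_weight, self.module.bias)


wrapper = DummyWrapper(nn.Linear(4, 2), DummyQuantizer())
print(wrapper(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
```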
LGTM. Since this PR also changes …
Have added a unit test about the interface of …
This PR contains two things:
1. Pass the weight to be quantized to `quantize_weight` directly, instead of obtaining it from the wrapper (see the diff above).
2. Since the weight actually used in the forward pass (e.g. the BN-folded weight) lives in `module.weight`, the QAT quantizer should quantize `module.weight` instead of `module.old_weight`.
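As an illustration of the `module.weight` vs. `module.old_weight` point, the sketch below mimics, in simplified form and with assumed attribute handling rather than NNI's actual wrapper code, how the original parameter can be kept aside while `weight` is replaced by the tensor used in the forward pass; quantizing `old_weight` would miss that replacement entirely.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)

# Keep the original fp32 parameter under a different name ...
conv.old_weight = nn.Parameter(conv.weight.detach().clone())
del conv._parameters['weight']

# ... and make `weight` a plain tensor refreshed before each forward pass,
# e.g. with BatchNorm statistics folded in (a scalar stand-in here).
conv.weight = conv.old_weight.detach() * 1.1

# The forward pass (and therefore inference) sees conv.weight, so that is
# the tensor QAT should quantize; conv.old_weight no longer matches it.
print(torch.equal(conv.weight, conv.old_weight))  # False
```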