🚨🚨🚨 [Quantization] Store the original dtype in the config as a private attribute 🚨🚨🚨 (#26761)
Changes from all commits: 2797129, b7e797f, 73d3109, 28d2b27, 316f776, 948de7e, 06700c7, 1ebe674
@@ -2177,8 +2177,25 @@ def to(self, *args, **kwargs):
                 "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
                 " model has already been set to the correct devices and casted to the correct `dtype`."
             )
-        else:
-            return super().to(*args, **kwargs)
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+            # For GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
+            # The correct API is to load the model with the desired dtype directly through `from_pretrained`.
+            dtype_present_in_args = False
+
+            if "dtype" not in kwargs:
+                for arg in args:
+                    if isinstance(arg, torch.dtype):
+                        dtype_present_in_args = True
+                        break
+            else:
+                dtype_present_in_args = True
+
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
+                    " `dtype` by passing the correct `torch_dtype` argument."
+                )
+        return super().to(*args, **kwargs)
 
     def half(self, *args):
         # Checks if the model is quantized
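To make the effect of the new `.to()` guard concrete, here is a minimal sketch of the intended user-facing behaviour. The checkpoint name is illustrative only, and the example assumes a GPTQ-quantized model (so optimum/auto-gptq are installed and the GPTQ branch above is taken):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative GPTQ checkpoint; any GPTQ-quantized model would behave the same way.
model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-GPTQ", device_map="auto")

# Moving the model between devices is not blocked by the new check,
# since no torch.dtype is passed.
model.to("cpu")

# Casting to a new dtype now raises a ValueError; the supported path is to
# reload the model with `from_pretrained(..., torch_dtype=...)` instead.
try:
    model.to(torch.float16)
except ValueError as err:
    print(err)
```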
@@ -3164,6 +3181,12 @@ def from_pretrained(
         if hasattr(model, "quantization_method"):
             model.is_quantized = True
 
+            # We store the original dtype for quantized models as we cannot easily retrieve it
+            # once the weights have been quantized.
+            # Note that once you have loaded a quantized model, you can't change its dtype, so this will
+            # remain a single source of truth.
+            config._pre_quantization_dtype = torch_dtype
+
         if isinstance(device_map, str):
             special_dtypes = {}
             if load_in_8bit or load_in_4bit:

Review thread on lines +3186 to +3187:

- Might be needed in the quantizer config?
- I think it is ok since users can always load back quantized models with new torch_dtype making that
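Because the attribute lives on the model config, downstream code can fall back to it once the quantized parameter dtypes no longer reflect the original precision. A minimal sketch, assuming a model loaded through `from_pretrained` with this patch applied; `get_pre_quantization_dtype` is a hypothetical helper, not part of the library:

```python
import torch

def get_pre_quantization_dtype(model) -> torch.dtype:
    # Hypothetical helper: prefer the dtype recorded before quantization when
    # it is available, otherwise fall back to the dtype of the first parameter.
    config = model.config
    if getattr(model, "is_quantized", False) and hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return next(model.parameters()).dtype
```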
Another review comment:

- we pop it because it should not be saved, no?
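Reading that comment as referring to the new private attribute, the idea would be to keep `_pre_quantization_dtype` out of the serialized config. A minimal sketch of that idea; `config_dict_for_saving` is a hypothetical helper, not the library's actual implementation:

```python
from transformers import PretrainedConfig

def config_dict_for_saving(config: PretrainedConfig) -> dict:
    # Hypothetical helper: drop the runtime-only private attribute so the
    # pre-quantization dtype never ends up in the saved config.json.
    output = config.to_dict()
    output.pop("_pre_quantization_dtype", None)
    return output
```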