🚨🚨🚨 [Quantization] Store the original dtype in the config as a private attribute 🚨🚨🚨 #26761

Merged · 8 commits · Oct 16, 2023

6 changes: 6 additions & 0 deletions src/transformers/configuration_utils.py
@@ -854,6 +854,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
else self.quantization_config
)

# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
Comment on lines +857 to +858

Collaborator:

we pop it because it should not be saved, no?

self.dict_torch_dtype_to_str(serializable_config_dict)

if "_flash_attn_2_enabled" in serializable_config_dict:
@@ -896,6 +899,9 @@ def to_dict(self) -> Dict[str, Any]:
else self.quantization_config
)

# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = output.pop("_pre_quantization_dtype", None)

self.dict_torch_dtype_to_str(output)

return output
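For context on the two `pop` calls above, a minimal standalone sketch (not part of the diff) of the serialization problem they avoid: `torch.dtype` values cannot be dumped to JSON, so `_pre_quantization_dtype` has to be dropped before `to_json_string()` is called.

```python
import json

import torch

# A config dict that still carries the private attribute added by this PR.
config_dict = {"model_type": "opt", "_pre_quantization_dtype": torch.float16}

try:
    json.dumps(config_dict)
except TypeError as exc:
    # torch.dtype is not JSON serializable, hence the pop in to_dict()/to_diff_dict().
    print(f"Serialization fails: {exc}")

# After popping the attribute, the config serializes cleanly.
config_dict.pop("_pre_quantization_dtype", None)
print(json.dumps(config_dict))
```
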
27 changes: 25 additions & 2 deletions src/transformers/modeling_utils.py
@@ -2177,8 +2177,25 @@ def to(self, *args, **kwargs):
"`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
else:
return super().to(*args, **kwargs)
elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
# For GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
# The correct API should be to load the model with the desired dtype directly through `from_pretrained`.
dtype_present_in_args = False

if "dtype" not in kwargs:
for arg in args:
if isinstance(arg, torch.dtype):
dtype_present_in_args = True
break
else:
dtype_present_in_args = True

if dtype_present_in_args:
raise ValueError(
"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
" `dtype` by passing the correct `torch_dtype` argument."
)
return super().to(*args, **kwargs)

def half(self, *args):
# Checks if the model is quantized
@@ -3164,6 +3181,12 @@ def from_pretrained(
if hasattr(model, "quantization_method"):
model.is_quantized = True

# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
Comment on lines +3186 to +3187

Collaborator:

Might be needed in the quantizer config?

Contributor (author):

I think it is ok since users can always load back quantized models with a new torch_dtype, making that _pre_quantization_dtype obsolete.

config._pre_quantization_dtype = torch_dtype

if isinstance(device_map, str):
special_dtypes = {}
if load_in_8bit or load_in_4bit:
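To illustrate the new guard in `to()`, a hypothetical usage sketch (the GPTQ checkpoint id is an assumption, and optimum/auto-gptq must be installed): moving the quantized model between devices still works, but casting it to a new dtype now raises the ValueError added above.

```python
import torch
from transformers import AutoModelForCausalLM

# Assumed GPTQ checkpoint; any GPTQ-quantized model should behave the same way.
model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-GPTQ", device_map="auto")

# Device moves are still allowed for GPTQ models.
model.to("cuda:0")

# Dtype casts are now rejected; reload via from_pretrained(torch_dtype=...) instead.
try:
    model.to(torch.float16)
except ValueError as exc:
    print(exc)
```
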
8 changes: 8 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -156,6 +156,14 @@ def test_memory_footprint(self):
linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == Params4bit)

def test_original_dtype(self):
r"""
A simple test to check if the model successfully stores the original dtype
"""
self.assertTrue(hasattr(self.model_4bit.config, "_pre_quantization_dtype"))
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
self.assertTrue(self.model_4bit.config._pre_quantization_dtype == torch.float16)

def test_linear_are_4bit(self):
r"""
A simple test to check if the model conversion has been done correctly by checking on the
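Mirroring the new 4-bit test above, a small sketch (the checkpoint id is an assumption) showing that only quantized models carry the private attribute, and that it records the `torch_dtype` used before quantization:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "facebook/opt-350m"  # assumed small checkpoint for illustration

model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
)
model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

print(model_4bit.config._pre_quantization_dtype)              # torch.float16
print(hasattr(model_fp16.config, "_pre_quantization_dtype"))  # False
```
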
8 changes: 8 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -186,6 +186,14 @@ def test_quantization_config_json_serialization(self):

_ = config.to_json_string()

def test_original_dtype(self):
r"""
A simple test to check if the model successfully stores the original dtype
"""
self.assertTrue(hasattr(self.model_8bit.config, "_pre_quantization_dtype"))
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
self.assertTrue(self.model_8bit.config._pre_quantization_dtype == torch.float16)

def test_memory_footprint(self):
r"""
A simple test to check if the model conversion has been done correctly by checking on the
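Tying the 8-bit tests back to the configuration_utils.py change, a short sketch (same assumed checkpoint) confirming that the config still serializes to JSON because `_pre_quantization_dtype` is popped in `to_dict()` / `to_diff_dict()`:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # assumed checkpoint for illustration
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# The private attribute is set on the config ...
assert model_8bit.config._pre_quantization_dtype == torch.float16

# ... but dropped during serialization, so this no longer fails on a torch.dtype value.
print(model_8bit.config.to_json_string())
```
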
20 changes: 20 additions & 0 deletions tests/quantization/gptq/test_gptq.py
@@ -145,6 +145,26 @@ def test_memory_footprint(self):

self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE)

def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assign a device to) a model after quantizing it will throw an error.
Also checks if other models are cast correctly.
"""
# This should work
_ = self.quantized_model.to(0)

with self.assertRaises(ValueError):
# Tries with a `dtype`
self.quantized_model.to(torch.float16)

def test_original_dtype(self):
r"""
A simple test to check if the model successfully stores the original dtype
"""
self.assertTrue(hasattr(self.quantized_model.config, "_pre_quantization_dtype"))
self.assertFalse(hasattr(self.model_fp16.config, "_pre_quantization_dtype"))
self.assertTrue(self.quantized_model.config._pre_quantization_dtype == torch.float16)

def test_quantized_layers_class(self):
"""
Simple test to check if the model conversion has been done correctly by checking on
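Finally, a sketch of the workflow the new error message points users to (checkpoint id assumed): instead of casting a GPTQ model with `.to`, reload it through `from_pretrained` with the desired `torch_dtype`, which then becomes the recorded `_pre_quantization_dtype`.

```python
import torch
from transformers import AutoModelForCausalLM

# Reloading with an explicit torch_dtype is the supported way to "change" the dtype.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-GPTQ",  # assumed GPTQ checkpoint
    torch_dtype=torch.float16,
    device_map="auto",
)

print(model.config._pre_quantization_dtype)  # torch.float16
```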