
Commit

Display warning for unknown quants config instead of an error (huggingface#35963)

* add supports_quant_method check

* fix

* add test and fix suggestions

* change logic slightly

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Authored by 2 people and committed by elvircrn on Feb 13, 2025
1 parent: e938276 · commit: 849e1f3
Showing 3 changed files with 40 additions and 2 deletions.
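
In short: a checkpoint whose config.json carries a quantization_config with an unrecognized quant_method no longer makes from_pretrained raise; the config is skipped with a warning and the model loads unquantized. A minimal sketch of the new behavior (the ./quantized-ckpt path is hypothetical):

    from transformers import AutoModel

    # Hypothetical local checkpoint whose config.json contains
    # "quantization_config": {"quant_method": "some_future_method"}.
    # Before this commit: ValueError("Unknown quantization type, got ...").
    # After: a warning is logged and the model is loaded without quantization.
    model = AutoModel.from_pretrained("./quantized-ckpt")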
6 changes: 4 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3634,7 +3634,10 @@ def from_pretrained(
 
         model_kwargs = kwargs
 
-        pre_quantized = getattr(config, "quantization_config", None) is not None
+        pre_quantized = hasattr(config, "quantization_config")
+        if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
+            pre_quantized = False
+
         if pre_quantized or quantization_config is not None:
             if pre_quantized:
                 config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
@@ -3647,7 +3650,6 @@
                 config.quantization_config,
                 pre_quantized=pre_quantized,
             )
-
         else:
             hf_quantizer = None
 
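Net effect of these two hunks, paraphrased as a standalone sketch (not a literal excerpt): when supports_quant_method rejects the checkpoint's config, pre_quantized flips to False and the load falls through to the unquantized branch.

    # Paraphrase of the new gating in from_pretrained.
    pre_quantized = hasattr(config, "quantization_config")
    if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
        pre_quantized = False  # unknown quant method: warn and treat as unquantized

    if pre_quantized or quantization_config is not None:
        ...  # construct an HfQuantizer as before
    else:
        hf_quantizer = None  # plain load, no quantizer involved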
23 changes: 23 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -15,6 +15,7 @@
 from typing import Dict, Optional, Union
 
 from ..models.auto.configuration_auto import AutoConfig
+from ..utils import logging
 from ..utils.quantization_config import (
     AqlmConfig,
     AwqConfig,
@@ -86,6 +87,8 @@
     "spqr": SpQRConfig,
 }
 
+logger = logging.get_logger(__name__)
+
 
 class AutoQuantizationConfig:
     """
@@ -199,3 +202,23 @@ def merge_quantization_configs(
         warnings.warn(warning_msg)
 
         return quantization_config
+
+    @staticmethod
+    def supports_quant_method(quantization_config_dict):
+        quant_method = quantization_config_dict.get("quant_method", None)
+        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+        elif quant_method is None:
+            raise ValueError(
+                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
+            )
+
+        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
+            logger.warning(
+                f"Unknown quantization type, got {quant_method} - supported types are:"
+                f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
+                "To remove the warning, you can delete the quantization_config attribute in config.json"
+            )
+            return False
+        return True
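
A quick illustration of the helper's contract, assuming the stock mapping keys (e.g. "gptq", "bitsandbytes_4bit") are registered:

    from transformers.quantizers.auto import AutoHfQuantizer

    AutoHfQuantizer.supports_quant_method({"quant_method": "gptq"})     # True, no warning
    AutoHfQuantizer.supports_quant_method({"quant_method": "unknown"})  # False, logs the warning above
    AutoHfQuantizer.supports_quant_method({"load_in_4bit": True})       # True, resolves to "bitsandbytes_4bit"
    AutoHfQuantizer.supports_quant_method({})                           # raises ValueError: no quant_method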
13 changes: 13 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -1819,6 +1819,19 @@ def test_cache_when_needed_at_train_time(self):
         self.assertIsNone(model_outputs.past_key_values)
         self.assertTrue(model.training)
 
+    def test_unknown_quantization_config(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config = BertConfig(
+                vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+            )
+            model = BertModel(config)
+            config.quantization_config = {"quant_method": "unknown"}
+            model.save_pretrained(tmpdir)
+            with self.assertLogs("transformers", level="WARNING") as cm:
+                BertModel.from_pretrained(tmpdir)
+            self.assertEqual(len(cm.records), 1)
+            self.assertTrue(cm.records[0].message.startswith("Unknown quantization type, got"))
+
 
 @slow
 @require_torch
