
Display warning for unknown quants config instead of an error #35963

Merged · 10 commits · Feb 4, 2025
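A minimal sketch of the user-visible effect of this change, assuming a checkpoint whose `config.json` declares a `quantization_config` with a `quant_method` that the installed transformers version does not recognize (the repo id and method name below are hypothetical):

```python
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint whose config.json contains something like:
#   "quantization_config": {"quant_method": "some_future_method", ...}
#
# Before this PR, from_pretrained raised an error for the unknown quant method.
# With this PR, a warning is logged and the quantization step is skipped, so the
# model is loaded with its regular (unquantized) weights instead.
model = AutoModelForCausalLM.from_pretrained("org/model-with-unknown-quant")
```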
6 changes: 4 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3634,7 +3634,10 @@ def from_pretrained(

model_kwargs = kwargs

pre_quantized = getattr(config, "quantization_config", None) is not None
pre_quantized = getattr(
    config, "quantization_config", None
) is not None and AutoHfQuantizer.supports_quant_method(config.quantization_config)

if pre_quantized or quantization_config is not None:
if pre_quantized:
config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
@@ -3647,7 +3650,6 @@
config.quantization_config,
pre_quantized=pre_quantized,
)

else:
hf_quantizer = None

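A condensed restatement of the gate added above, wrapped in a helper for illustration only (the function name is made up; `config` and `quantization_config` are the objects `from_pretrained` already handles):

```python
from transformers.quantizers.auto import AutoHfQuantizer


def resolve_pre_quantized(config, quantization_config=None):
    """Simplified sketch of the updated logic in from_pretrained, not the literal code."""
    has_declared_config = getattr(config, "quantization_config", None) is not None
    # pre_quantized is now True only when a quantization_config is declared AND
    # its quant_method is one the installed quantizers actually support.
    pre_quantized = has_declared_config and AutoHfQuantizer.supports_quant_method(
        config.quantization_config
    )
    if pre_quantized or quantization_config is not None:
        ...  # merge configs and build an hf_quantizer, as in the diff above
    else:
        # Unknown quant method: supports_quant_method already warned, so loading
        # proceeds with hf_quantizer = None (i.e., no quantization).
        hf_quantizer = None
    return pre_quantized
```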
24 changes: 24 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -15,6 +15,7 @@
from typing import Dict, Optional, Union

from ..models.auto.configuration_auto import AutoConfig
from ..utils import logging
from ..utils.quantization_config import (
AqlmConfig,
AwqConfig,
@@ -82,6 +83,8 @@
"vptq": VptqConfig,
}

logger = logging.get_logger(__name__)


class AutoQuantizationConfig:
"""
@@ -195,3 +198,24 @@ def merge_quantization_configs(
warnings.warn(warning_msg)

return quantization_config

@staticmethod
def supports_quant_method(quantization_config_dict):
    quant_method = quantization_config_dict.get("quant_method", None)
    # We need special care for bnb models to make sure everything stays backward-compatible (BC).
    if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
        suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
        quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
    elif quant_method is None:
        raise ValueError(
            "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
        )

    if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
        logger.warning(
            f"Unknown quantization type, got {quant_method} - supported types are:"
            f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
            "To remove the warning, you can delete the quantization_config attribute in config.json"
        )
        return False
    return True
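A small usage sketch of the new helper (the unknown method name below is made up; `"awq"` is assumed to be one of the registered mapping keys, per the `AwqConfig` import above):

```python
from transformers.quantizers.auto import AutoHfQuantizer

# A config dict whose quant_method is registered passes the check.
assert AutoHfQuantizer.supports_quant_method({"quant_method": "awq"})

# An unrecognized method logs the warning above and returns False instead of
# raising, which is what lets from_pretrained skip quantization gracefully.
assert not AutoHfQuantizer.supports_quant_method({"quant_method": "made_up_method"})
```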