diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index 7ab236a55d59..0d694eaa72d6 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
 RUN python3 -m pip install --no-cache-dir einops
 
 # Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
 
 # For bettertransformer + gptq
 RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e792c076b7a7..6cb1c9f16f50 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3575,6 +3575,9 @@ def from_pretrained(
             if quantization_config is None:
                 quantization_config = AwqConfig.from_dict(config.quantization_config)
 
+            if quantization_config.modules_to_not_convert is not None:
+                modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
+
             model, has_been_replaced = replace_with_awq_linear(
                 model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
             )
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 0166cb4bf875..3684dcc76fce 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -564,6 +564,10 @@ class AwqConfig(QuantizationConfigMixin):
             The Maximum sequence length to generate when using fusing.
         modules_to_fuse (`dict`, *optional*, default to `None`):
             Overwrite the natively supported fusing scheme with the one specified by the users.
+        modules_to_not_convert (`list`, *optional*, default to `None`):
+            The list of modules not to quantize, useful for quantizing models that explicitly require some
+            modules to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+            Note that you cannot quantize directly with transformers; please refer to the `AutoAWQ` documentation for quantizing HF models.
     """
 
     def __init__(
@@ -576,6 +580,7 @@ def __init__(
         do_fuse: Optional[bool] = None,
         fuse_max_seq_len: Optional[int] = None,
         modules_to_fuse: Optional[dict] = None,
+        modules_to_not_convert: Optional[List] = None,
        **kwargs,
     ):
         self.quant_method = QuantizationMethod.AWQ
@@ -586,6 +591,7 @@ def __init__(
         self.version = version
         self.backend = backend
         self.fuse_max_seq_len = fuse_max_seq_len
+        self.modules_to_not_convert = modules_to_not_convert
         self.modules_to_fuse = modules_to_fuse
 
         if do_fuse is None:
@@ -638,6 +644,19 @@ def post_init(self):
                     f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                 )
 
+        if self.modules_to_not_convert is not None:
+            awq_version_supports_non_conversion = False
+            MIN_AWQ_VERSION = "0.1.8"
+            if is_auto_awq_available():
+                awq_version_supports_non_conversion = version.parse(
+                    importlib.metadata.version("autoawq")
+                ) >= version.parse(MIN_AWQ_VERSION)
+
+            if not awq_version_supports_non_conversion:
+                raise ValueError(
+                    f"Your current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
+                )
+
         if self.do_fuse and self.modules_to_fuse is not None:
             required_keys = [
                 "hidden_size",
diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py
index 8f9cbd91aad7..3f5118635ac6 100644
--- a/tests/quantization/autoawq/test_awq.py
+++ b/tests/quantization/autoawq/test_awq.py
@@ -88,6 +88,7 @@ def test_from_dict(self):
 class AwqTest(unittest.TestCase):
     model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
     dummy_transformers_model_name = "bigscience/bloom-560m"
+    model_with_no_k_proj_quantized = "hf-internal-testing/opt-125m-awq-no-k-proj"
 
     input_text = "Hello my name is"
 
@@ -223,6 +224,24 @@ def test_quantized_model_multi_gpu(self):
 
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
+    def test_quantized_model_no_k_proj_quantized(self):
+        """
+        Simple test that checks if the quantized model works properly when one of its modules (here `k_proj`) was not quantized
+        """
+        dummy_input = torch.LongTensor([[0, 1, 0]]).to(torch_device)
+
+        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_with_no_k_proj_quantized).to(torch_device)
+
+        self.assertTrue(isinstance(quantized_model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear))
+        self.assertFalse(isinstance(quantized_model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))
+
+        EXPECTED_OUTPUT = torch.LongTensor([[0, 1, 0, 50118, 50118, 133, 248, 12, 134, 16, 10, 372, 2031]]).to(
+            torch_device
+        )
+
+        output = quantized_model.generate(dummy_input, max_new_tokens=10)
+        self.assertTrue((EXPECTED_OUTPUT == output).all())
+
 
 @slow
 @require_torch_gpu
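For reviewers, a minimal usage sketch of what this change enables, built around the test checkpoint added above (the `device_map` value and the generation settings are illustrative assumptions, not part of the diff): an AWQ checkpoint produced with AutoAWQ >= 0.1.8 that skips some modules now loads with those modules kept as regular `nn.Linear` layers.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint from the new test: an OPT-125m AWQ model whose `k_proj` layers were
# left unquantized; the skipped modules are recorded as `modules_to_not_convert`
# in the checkpoint's quantization config.
model_id = "hf-internal-testing/opt-125m-awq-no-k-proj"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")

# `k_proj` stays a plain nn.Linear, while the other projections are AWQ modules.
assert isinstance(model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear)
assert not isinstance(model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
```

Passing an `AwqConfig` with `modules_to_not_convert=[...]` to `from_pretrained` extends the default skip list the same way, per the `modeling_utils.py` change above.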