[Awq] Enable the possibility to skip quantization for some target modules #27950

Merged · 11 commits · Dec 25, 2023
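As a quick illustration of what this PR enables (a minimal sketch, assuming a CUDA device, `autoawq >= 0.1.8`, and the small test checkpoint referenced in the new test below):

```python
import torch
from transformers import AutoModelForCausalLM, AwqConfig

# Checkpoint used by the test added in this PR: an OPT-125m AWQ model whose
# k_proj layers were left unquantized via modules_to_not_convert.
model_id = "hf-internal-testing/opt-125m-awq-no-k-proj"
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")

attn = model.model.decoder.layers[0].self_attn
print(isinstance(attn.k_proj, torch.nn.Linear))  # True  -> kept in original precision
print(isinstance(attn.v_proj, torch.nn.Linear))  # False -> replaced by an AWQ linear

# The skip list can also be set explicitly on the config object:
config = AwqConfig(modules_to_not_convert=["k_proj"])
```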
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
@@ -3575,6 +3575,9 @@ def from_pretrained(
if quantization_config is None:
quantization_config = AwqConfig.from_dict(config.quantization_config)

if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)

model, has_been_replaced = replace_with_awq_linear(
model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
)
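For context, the skip happens inside `replace_with_awq_linear`: linear layers whose qualified name matches an entry of `modules_to_not_convert` are left untouched. A simplified, illustrative sketch of that name-matching idea (not the exact implementation in transformers):

```python
import torch.nn as nn


def mark_awq_targets(model: nn.Module, modules_to_not_convert, prefix=""):
    """Illustrative sketch only: report which nn.Linear layers would be replaced
    by AWQ quantized linears and which are skipped because their qualified name
    contains an entry of `modules_to_not_convert`."""
    for name, child in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, nn.Linear):
            if any(skip in full_name for skip in modules_to_not_convert):
                print(f"skip     {full_name}")  # e.g. model.decoder.layers.0.self_attn.k_proj
            else:
                print(f"quantize {full_name}")
        else:
            mark_awq_targets(child, modules_to_not_convert, prefix=full_name)
```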
19 changes: 19 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -564,6 +564,10 @@ class AwqConfig(QuantizationConfigMixin):
The Maximum sequence length to generate when using fusing.
modules_to_fuse (`dict`, *optional*, default to `None`):
Overwrite the natively supported fusing scheme with the one specified by the users.
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require some
modules to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
Note that you cannot quantize directly with transformers; please refer to the `AutoAWQ` documentation for quantizing HF models.
"""

def __init__(
@@ -576,6 +580,7 @@ def __init__(
do_fuse: Optional[bool] = None,
fuse_max_seq_len: Optional[int] = None,
modules_to_fuse: Optional[dict] = None,
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
@@ -586,6 +591,7 @@ def __init__(
self.version = version
self.backend = backend
self.fuse_max_seq_len = fuse_max_seq_len
self.modules_to_not_convert = modules_to_not_convert

self.modules_to_fuse = modules_to_fuse
if do_fuse is None:
@@ -638,6 +644,19 @@ def post_init(self):
f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
)

if self.modules_to_not_convert is not None:
awq_version_supports_non_conversion = False
MIN_AWQ_VERSION = "0.1.8"
if is_auto_awq_available():
awq_version_supports_non_conversion = version.parse(
importlib.metadata.version("autoawq")
) >= version.parse(MIN_AWQ_VERSION)

if not awq_version_supports_non_conversion:
raise ValueError(
f"You current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
)

if self.do_fuse and self.modules_to_fuse is not None:
required_keys = [
"hidden_size",
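The new `post_init` branch mirrors the existing fusing check: `modules_to_not_convert` is rejected unless `autoawq >= 0.1.8` (the version the Dockerfile above was bumped to). A small sketch of performing the equivalent check up front, assuming `autoawq` is installed:

```python
import importlib.metadata

from packaging import version
from transformers import AwqConfig

MIN_AWQ_VERSION = "0.1.8"
installed = version.parse(importlib.metadata.version("autoawq"))

if installed >= version.parse(MIN_AWQ_VERSION):
    # post_init() accepts this, since the installed autoawq supports skipping modules
    config = AwqConfig(modules_to_not_convert=["k_proj"])
else:
    raise RuntimeError(f"autoawq {installed} is too old; >= {MIN_AWQ_VERSION} is required")
```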
19 changes: 19 additions & 0 deletions tests/quantization/autoawq/test_awq.py
@@ -88,6 +88,7 @@ def test_from_dict(self):
class AwqTest(unittest.TestCase):
model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
dummy_transformers_model_name = "bigscience/bloom-560m"
model_with_no_k_proj_quantized = "hf-internal-testing/opt-125m-awq-no-k-proj"

input_text = "Hello my name is"

@@ -223,6 +224,24 @@ def test_quantized_model_multi_gpu(self):

self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

def test_quantized_model_no_k_proj_quantized(self):
"""
Simple test that checks that a quantized model whose `k_proj` layers were kept in their original precision (via `modules_to_not_convert`) loads and generates properly
"""
dummy_input = torch.LongTensor([[0, 1, 0]]).to(torch_device)

quantized_model = AutoModelForCausalLM.from_pretrained(self.model_with_no_k_proj_quantized).to(torch_device)

self.assertTrue(isinstance(quantized_model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear))
self.assertFalse(isinstance(quantized_model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))

EXPECTED_OUTPUT = torch.LongTensor([[0, 1, 0, 50118, 50118, 133, 248, 12, 134, 16, 10, 372, 2031]]).to(
torch_device
)

output = quantized_model.generate(dummy_input, max_new_tokens=10)
self.assertTrue(torch.equal(EXPECTED_OUTPUT, output))


@slow
@require_torch_gpu
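The docstring added above points out that the skipping itself has to happen at quantization time with AutoAWQ rather than in transformers. A rough sketch of how a checkpoint like the one used in this test could be produced; the exact AutoAWQ API, and in particular where `modules_to_not_convert` is passed, is an assumption here and may differ between versions, so check the AutoAWQ documentation:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "facebook/opt-125m"
quant_path = "opt-125m-awq-no-k-proj"

# "modules_to_not_convert" as a quant_config entry is an assumption based on this PR;
# consult the AutoAWQ docs for where this option lives in your version.
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": ["k_proj"],
}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```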