From 387663e571f0442feece6acf0c6016cc569cb4aa Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 21:22:49 +0800 Subject: [PATCH] Enable gptqmodel (#35012) * gptqmodel Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * update readme Signed-off-by: jiqing-feng * gptqmodel need use checkpoint_format (#1) * gptqmodel need use checkpoint_format * fix quantize * Update quantization_config.py * Update quantization_config.py * Update quantization_config.py --------- Co-authored-by: ZX-ModelCloud Co-authored-by: Qubitium-ModelCloud * Revert quantizer_gptq.py (#2) * revert quantizer_gptq.py change * pass **kwargs * limit gptqmodel and optimum version Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * fix warning Signed-off-by: jiqing-feng * fix version check Signed-off-by: jiqing-feng * revert unrelated changes Signed-off-by: jiqing-feng * enable gptqmodel tests Signed-off-by: jiqing-feng * fix requires gptq Signed-off-by: jiqing-feng * Fix Transformer compat (#3) * revert quantizer_gptq.py change * pass **kwargs * add meta info * cleanup * cleanup * Update quantization_config.py * hf_select_quant_linear pass checkpoint_format and meta * fix GPTQTestCUDA * Update test_gptq.py * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2 * cleanup * add backend * cleanup * cleanup * no need check exllama version * Update quantization_config.py * lower checkpoint_format and backend * check none * cleanup * Update quantization_config.py * fix self.use_exllama == False * spell * fix unittest * fix unittest --------- Co-authored-by: LRL Co-authored-by: Qubitium-ModelCloud * fix format Signed-off-by: jiqing-feng * fix format again Signed-off-by: jiqing-feng * update gptqmodel version (#6) * update gptqmodel version * update gptqmodel version * fix unit test (#5) * update gptqmodel version * update gptqmodel version * "not self.use_exllama" is not equivalent to "self.use_exllama==False" * fix unittest * update gptqmodel version * backend is loading_attibutes (#7) * fix format and tests Signed-off-by: jiqing-feng * fix memory check Signed-off-by: jiqing-feng * fix device mismatch Signed-off-by: jiqing-feng * fix result check Signed-off-by: jiqing-feng * Update src/transformers/quantizers/quantizer_gptq.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/quantizers/quantizer_gptq.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/quantizers/quantizer_gptq.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * update tests Signed-off-by: jiqing-feng * review: update docs (#10) * review: update docs (#12) * review: update docs * fix typo * update tests for gptqmodel Signed-off-by: jiqing-feng * update document (#9) * update overview.md * cleanup * Update overview.md * Update overview.md * Update overview.md * update gptq.md * Update gptq.md * Update gptq.md * Update gptq.md * Update gptq.md * Update gptq.md * Update gptq.md --------- Co-authored-by: Qubitium-ModelCloud * typo * doc note for asymmetric quant * typo with apple silicon(e) * typo for marlin * column name revert: review * doc rocm support * Update docs/source/en/quantization/gptq.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/quantization/gptq.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/quantization/gptq.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/quantization/gptq.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/quantization/overview.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/quantization/overview.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Signed-off-by: jiqing-feng Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com> Co-authored-by: ZX-ModelCloud Co-authored-by: Qubitium-ModelCloud Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com> Co-authored-by: LRL Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/quantization/gptq.md | 52 +++++++- docs/source/en/quantization/overview.md | 56 +++++--- src/transformers/quantizers/quantizer_gptq.py | 43 ++++-- src/transformers/testing_utils.py | 7 +- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 + src/transformers/utils/quantization_config.py | 59 +++++++-- tests/quantization/gptq/test_gptq.py | 125 +++++++++++++----- tests/utils/test_cache_utils.py | 4 +- 9 files changed, 266 insertions(+), 86 deletions(-) diff --git a/docs/source/en/quantization/gptq.md b/docs/source/en/quantization/gptq.md index 5713ef4132a9..1534a977f343 100644 --- a/docs/source/en/quantization/gptq.md +++ b/docs/source/en/quantization/gptq.md @@ -22,15 +22,42 @@ Try GPTQ quantization with PEFT in this [notebook](https://colab.research.google -The [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) library implements the GPTQ algorithm, a post-training quantization technique where each row of the weight matrix is quantized independently to find a version of the weights that minimizes the error. These weights are quantized to int4, but they're restored to fp16 on the fly during inference. This can save your memory-usage by 4x because the int4 weights are dequantized in a fused kernel rather than a GPU's global memory, and you can also expect a speedup in inference because using a lower bitwidth takes less time to communicate. +Both [GPTQModel](https://github.com/ModelCloud/GPTQModel) and [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) libraries implement the GPTQ algorithm, a post-training quantization technique where each row of the weight matrix is quantized independently to find a version of the weights that minimizes error. These weights are quantized to int4, stored as int32 (int4 x 8) and dequantized (restored) to fp16 on the fly during inference. This can save memory by almost 4x because the int4 weights are often dequantized in a fused kernel. You can also expect a substantial speedup in inference due to lower bandwidth requirements for lower bitwidth. -Before you begin, make sure the following libraries are installed: +[GPTQModel](https://github.com/ModelCloud/GPTQModel) started as a maintained fork of AutoGPTQ but has since differentiated itself with the following major differences. + +* Model support: GPTQModel continues to support all of the latest LLM models. +* Multimodal support: GPTQModel supports accurate quantization of Qwen 2-VL and Ovis 1.6-VL image-to-text models. +* Platform support: Linux, macOS (Apple Silicon), and Windows 11. +* Hardware support: NVIDIA CUDA, AMD ROCm, Apple Silicon M1/MPS /CPU, Intel/AMD CPU, and Intel Datacenter Max/Arc GPUs. +* Asymmetric support: Asymmetric quantization can potentially introduce lower quantization errors compared to symmetric quantization. However, it is not backward compatible with AutoGPTQ, and not all kernels, such as Marlin, support asymmetric quantization. +* IPEX kernel for Intel/AMD accelerated CPU and Intel GPU (Datacenter Max/Arc GPUs) support. +* Updated Marlin kernel from Neural Magic optimized for A100 (Ampere). +* Updated kernels with auto-padding for legacy model support and models with non-uniform in/out-features. +* Faster quantization, lower memory usage, and more accurate default quantization via GPTQModel quantization APIs. +* User and developer friendly APIs. + + +[AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) will likely be deprecated in the future due the lack of continued support for new models and features. + +Before you begin, make sure the following libraries are installed and updated to the latest release: ```bash -pip install auto-gptq pip install --upgrade accelerate optimum transformers ``` +Then install either GPTQModel or AutoGPTQ. + +```bash +pip install gptqmodel --no-build-isolation +``` + +or + +```bash +pip install auto-gptq --no-build-isolation +``` + To quantize a model (currently only supported for text models), you need to create a [`GPTQConfig`] class and set the number of bits to quantize to, a dataset to calibrate the weights for quantization, and a tokenizer to prepare the dataset. ```py @@ -92,9 +119,22 @@ from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto") ``` +## Marlin + +[Marlin](https://github.com/IST-DASLab/marlin) is a 4-bit only CUDA GPTQ kernel, highly optimized for the NVIDIA A100 GPU (Ampere) architecture. Loading, dequantization, and execution of post-dequantized weights are highly parallelized, offering a substantial inference improvement versus the original CUDA GPTQ kernel. Marlin is only available for quantized inference and does not support model quantization. + +Marlin inference can be activated with the `backend` parameter in [`GPTQConfig`]. + +```py + +from transformers import AutoModelForCausalLM, GPTQConfig + +model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=GPTQConfig(bits=4, backend="marlin")) +``` + ## ExLlama -[ExLlama](https://github.com/turboderp/exllama) is a Python/C++/CUDA implementation of the [Llama](model_doc/llama) model that is designed for faster inference with 4-bit GPTQ weights (check out these [benchmarks](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)). The ExLlama kernel is activated by default when you create a [`GPTQConfig`] object. To boost inference speed even further, use the [ExLlamaV2](https://github.com/turboderp/exllamav2) kernels by configuring the `exllama_config` parameter: +[ExLlama](https://github.com/turboderp/exllama) is a CUDA implementation of the [Llama](model_doc/llama) model that is designed for faster inference with 4-bit GPTQ weights (check out these [benchmarks](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)). The ExLlama kernel is activated by default when you create a [`GPTQConfig`] object. To boost inference speed even further, use the [ExLlamaV2](https://github.com/turboderp/exllamav2) kernels by configuring the `exllama_config` parameter: ```py import torch @@ -110,11 +150,11 @@ Only 4-bit models are supported, and we recommend deactivating the ExLlama kerne -The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2), then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file. +The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ or GPTQModel, then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file. ```py import torch from transformers import AutoModelForCausalLM, GPTQConfig gptq_config = GPTQConfig(bits=4, use_exllama=False) model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="cpu", quantization_config=gptq_config) -``` \ No newline at end of file +``` diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 48840fad646f..dfe680832b19 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -45,32 +45,50 @@ In short, supporting a wide range of quantization methods allows you to pick the Use the table below to help you decide which quantization method to use. -| Quantization method | On the fly quantization | CPU | CUDA GPU | RoCm GPU (AMD) | Metal (Apple Silicon) | Intel GPU | torch.compile() support | Number of bits | Supports fine-tuning (through PEFT) | Serializable with πŸ€— transformers | πŸ€— transformers support | Link to library | -|-------------------------------------|-------------------------|-----|----------|----------------|-----------------------|-----------|-------------------------|----------------|-------------------------------------|--------------|------------------------|---------------------------------------------| -| [AQLM](./aqlm) | πŸ”΄ | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 🟒 | 1 / 2 | 🟒 | 🟒 | 🟒 | https://github.com/Vahe1994/AQLM | -| [AWQ](./awq) | πŸ”΄ | 🟒 | 🟒 | 🟒 | πŸ”΄ | 🟒 | ? | 4 | 🟒 | 🟒 | 🟒 | https://github.com/casper-hansen/AutoAWQ | -| [bitsandbytes](./bitsandbytes) | 🟒 | 🟑 * | 🟒 | 🟑 * | πŸ”΄ ** | 🟑 * | πŸ”΄ (soon!) | 4 / 8 | 🟒 | 🟒 | 🟒 | https://github.com/bitsandbytes-foundation/bitsandbytes | -| [compressed-tensors](./compressed_tensors) | πŸ”΄ | 🟒 | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 1 - 8 | 🟒 | 🟒 | 🟒 | https://github.com/neuralmagic/compressed-tensors | -| [EETQ](./eetq) | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | ? | 8 | 🟒 | 🟒 | 🟒 | https://github.com/NetEase-FuXi/EETQ | -| GGUF / GGML (llama.cpp) | 🟒 | 🟒 | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | 1 - 8 | πŸ”΄ | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp | -| [GPTQ](./gptq) | πŸ”΄ | πŸ”΄ | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 2 - 3 - 4 - 8 | 🟒 | 🟒 | 🟒 | https://github.com/AutoGPTQ/AutoGPTQ | -| [HIGGS](./higgs) | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 🟒 | 2 - 4 | πŸ”΄ | 🟒 | 🟒 | https://github.com/HanGuo97/flute | -| [HQQ](./hqq) | 🟒 | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 🟒 | 1 - 8 | 🟒 | πŸ”΄ | 🟒 | https://github.com/mobiusml/hqq/ | -| [optimum-quanto](./quanto) | 🟒 | 🟒 | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | 🟒 | 2 / 4 / 8 | πŸ”΄ | πŸ”΄ | 🟒 | https://github.com/huggingface/optimum-quanto | -| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | πŸ”΄ | 8 | πŸ”΄ | 🟒 | 🟒 | https://github.com/pytorch/FBGEMM | -| [torchao](./torchao.md) | 🟒 | | 🟒 | πŸ”΄ | partial support (int4 weight only) | πŸ”΄ | | 4 / 8 | | πŸŸ’πŸ”΄ | 🟒 | https://github.com/pytorch/ao | -| [VPTQ](./vptq) | πŸ”΄ | πŸ”΄ | 🟒 | 🟑 | πŸ”΄ | πŸ”΄ | 🟒 | 1 - 8 | πŸ”΄ | 🟒 | 🟒 | https://github.com/microsoft/VPTQ | +| Quantization Method | On the fly quantization | CPU | CUDA GPU | ROCm GPU | Metal (Apple Silicon) | Intel GPU | Torch compile() | Bits | PEFT Fine Tuning | Serializable with πŸ€—Transformers | πŸ€—Transformers Support | Link to library | +|-----------------------------------------------|----------------------|-----------------|----------|-----------|------------------------------------|-----------------|-----------------|---------------|------------------|-----------------------------|-------------------------|---------------------------------------------| +| [AQLM](./aqlm.md) | πŸ”΄ | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 🟒 | 1/2 | 🟒 | 🟒 | 🟒 | https://github.com/Vahe1994/AQLM | +| [AWQ](./awq.md) | πŸ”΄ | 🟒 | 🟒 | 🟒 | πŸ”΄ | 🟒 | ? | 4 | 🟒 | 🟒 | 🟒 | https://github.com/casper-hansen/AutoAWQ | +| [bitsandbytes](./bitsandbytes.md) | 🟒 | 🟑 1 | 🟒 | 🟑 1 | πŸ”΄ 2 | 🟑 1 | πŸ”΄ 1 | 4/8 | 🟒 | 🟒 | 🟒 | https://github.com/bitsandbytes-foundation/bitsandbytes | +| [compressed-tensors](./compressed_tensors.md) | πŸ”΄ | 🟒 | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 1/8 | 🟒 | 🟒 | 🟒 | https://github.com/neuralmagic/compressed-tensors | +| [EETQ](./eetq.md) | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | ? | 8 | 🟒 | 🟒 | 🟒 | https://github.com/NetEase-FuXi/EETQ | +| [GGUF / GGML (llama.cpp)](../gguf.md) | 🟒 | 🟒 | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | 1/8 | πŸ”΄ | [See Notes](../gguf.md) | [See Notes](../gguf.md) | https://github.com/ggerganov/llama.cpp | +| [GPTQModel](./gptq.md) | πŸ”΄ | 🟒 3 | 🟒 | 🟒 | 🟒 | 🟒 4 | πŸ”΄ | 2/3/4/8 | 🟒 | 🟒 | 🟒 | https://github.com/ModelCloud/GPTQModel | +| [AutoGPTQ](./gptq.md) | πŸ”΄ | πŸ”΄ | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 2/3/4/8 | 🟒 | 🟒 | 🟒 | https://github.com/AutoGPTQ/AutoGPTQ | +| [HIGGS](./higgs.md) | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 🟒 | 2/4 | πŸ”΄ | 🟒 | 🟒 | https://github.com/HanGuo97/flute | +| [HQQ](./hqq.md) | 🟒 | 🟒 | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | 🟒 | 1/8 | 🟒 | πŸ”΄ | 🟒 | https://github.com/mobiusml/hqq/ | +| [optimum-quanto](./quanto.md) | 🟒 | 🟒 | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | 🟒 | 2/4/8 | πŸ”΄ | πŸ”΄ | 🟒 | https://github.com/huggingface/optimum-quanto | +| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟒 | πŸ”΄ | 🟒 | πŸ”΄ | πŸ”΄ | πŸ”΄ | πŸ”΄ | 8 | πŸ”΄ | 🟒 | 🟒 | https://github.com/pytorch/FBGEMM | +| [torchao](./torchao.md) | 🟒 | | 🟒 | πŸ”΄ | 🟑 5 | πŸ”΄ | | 4/8 | | πŸŸ’πŸ”΄ | 🟒 | https://github.com/pytorch/ao | +| [VPTQ](./vptq.md) | πŸ”΄ | πŸ”΄ | 🟒 | 🟑 | πŸ”΄ | πŸ”΄ | 🟒 | 1/8 | πŸ”΄ | 🟒 | 🟒 | https://github.com/microsoft/VPTQ | + +**1:** bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. -\* bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). + + + + +**2:** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships. + + + + -We value your feedback to help identify bugs before the full release! Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links. +**3:** GPTQModel[CPU] supports 4-bit via IPEX on Intel/AMD and full bit range via Torch on Intel/AMD/Apple Silicon. -\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships. +**4:** GPTQModel[Intel GPU] via IPEX only supports 4-bit for Intel Datacenter Max/Arc GPUs. + + + + + +**5:** torchao only supports int4 weight on Metal (Apple Silicon). + + - \ No newline at end of file diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py index d47a2ba79cb6..8c84c8817a44 100644 --- a/src/transformers/quantizers/quantizer_gptq.py +++ b/src/transformers/quantizers/quantizer_gptq.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging +from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin @@ -35,11 +35,11 @@ class GptqHfQuantizer(HfQuantizer): """ Quantizer of the GPTQ method - for GPTQ the quantizer support calibration of the model through - `auto_gptq` package. Quantization is done under the hood for users if they load a non-prequantized model. + `auto_gptq` or `gptqmodel` package. Quantization is done under the hood for users if they load a non-prequantized model. """ requires_calibration = False - required_packages = ["optimum", "auto_gptq"] + required_packages = ["optimum", "auto_gptq", "gptqmodel"] optimum_quantizer = None def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): @@ -54,19 +54,30 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): def validate_environment(self, *args, **kwargs): if not is_optimum_available(): raise ImportError("Loading a GPTQ quantized model requires optimum (`pip install optimum`)") + if is_auto_gptq_available() and is_gptqmodel_available(): + logger.warning("Detected gptqmodel and auto-gptq, will use gptqmodel") - if not is_auto_gptq_available(): - raise ImportError( - "Loading a GPTQ quantized model requires the auto-gptq library (`pip install auto-gptq`)" - ) - - gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") + gptq_supports_cpu = ( + is_auto_gptq_available() + and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") + ) or is_gptqmodel_available() if not gptq_supports_cpu and not torch.cuda.is_available(): raise RuntimeError("GPU is required to quantize or run quantize model.") - elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"): + elif not (is_auto_gptq_available() or is_gptqmodel_available()): raise ImportError( - "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`" + "Loading a GPTQ quantized model requires gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) library. " ) + elif is_auto_gptq_available() and version.parse(importlib.metadata.version("auto_gptq")) < version.parse( + "0.4.2" + ): + raise ImportError( + "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq` or use gptqmodel by `pip install gptqmodel>=1.4.3`." + ) + elif is_gptqmodel_available() and ( + version.parse(importlib.metadata.version("gptqmodel")) < version.parse("1.4.3") + or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.99") + ): + raise ImportError("The gptqmodel version should be >= 1.4.3, optimum version should >= 1.24.0") def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: @@ -76,12 +87,20 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.") return torch_dtype + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": torch.device("cpu")} + # Only with auto-gptq do not support CPU, we should move the model to cuda if available. + if not is_gptqmodel_available() and device_map in ("cpu", {"": torch.device("cpu")}): + device_map == {"": 0} + return device_map + def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): if model.__class__.main_input_name != "input_ids": raise RuntimeError("We can only quantize pure text model.") if self.pre_quantized: - model = self.optimum_quantizer.convert_model(model) + model = self.optimum_quantizer.convert_model(model, **kwargs) def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): if self.pre_quantized: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index aa26c2c3df86..5bbf27713805 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -85,6 +85,7 @@ is_g2p_en_available, is_galore_torch_available, is_gguf_available, + is_gptqmodel_available, is_grokadamw_available, is_hadamard_available, is_hqq_available, @@ -1207,11 +1208,13 @@ def require_tensorboard(test_case): return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard") -def require_auto_gptq(test_case): +def require_gptq(test_case): """ Decorator for auto_gptq dependency """ - return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) + return unittest.skipUnless( + is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq" + )(test_case) def require_hqq(test_case): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 8dce39fcf617..919f9f8bde03 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -147,6 +147,7 @@ is_g2p_en_available, is_galore_torch_available, is_gguf_available, + is_gptqmodel_available, is_grokadamw_available, is_hadamard_available, is_hqq_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 707903c702c7..802d0ca68974 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -144,6 +144,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _openai_available = _is_package_available("openai") _optimum_available = _is_package_available("optimum") _auto_gptq_available = _is_package_available("auto_gptq") +_gptqmodel_available = _is_package_available("gptqmodel") # `importlib.metadata.version` doesn't work with `awq` _auto_awq_available = importlib.util.find_spec("awq") is not None _quanto_available = _is_package_available("quanto") @@ -1041,6 +1042,10 @@ def is_auto_gptq_available(): return _auto_gptq_available +def is_gptqmodel_available(): + return _gptqmodel_available + + def is_eetq_available(): return _eetq_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 1883e93deb58..b356b3d9b0ab 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -25,7 +25,15 @@ from packaging import version -from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging +from ..utils import ( + is_auto_awq_available, + is_gptqmodel_available, + is_hqq_available, + is_torch_available, + is_torchao_available, + logging, +) +from .import_utils import is_auto_gptq_available if is_torch_available(): @@ -581,8 +589,16 @@ class GPTQConfig(QuantizationConfigMixin): Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers. + checkpoint_format (`str`, *optional*, defaults to `"gptq"`): + GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only. + meta (`Dict[str, any]`, *optional*): + Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta. + i.e. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"] + backend (`str`, *optional*): + Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only + valid value is None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py use_cuda_fp16 (`bool`, *optional*, defaults to `False`): - Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. + Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. Auto-gptq only. model_seqlen (`int`, *optional*): The maximum sequence length that the model can take. block_name_to_quantize (`str`, *optional*): @@ -622,6 +638,9 @@ def __init__( desc_act: bool = False, sym: bool = True, true_sequential: bool = True, + checkpoint_format: str = "gptq", + meta: Optional[Dict[str, any]] = None, + backend: Optional[str] = None, use_cuda_fp16: bool = False, model_seqlen: Optional[int] = None, block_name_to_quantize: Optional[str] = None, @@ -644,6 +663,9 @@ def __init__( self.desc_act = desc_act self.sym = sym self.true_sequential = true_sequential + self.checkpoint_format = checkpoint_format.lower() + self.meta = meta + self.backend = backend.lower() if isinstance(backend, str) else backend self.use_cuda_fp16 = use_cuda_fp16 self.model_seqlen = model_seqlen self.block_name_to_quantize = block_name_to_quantize @@ -660,7 +682,14 @@ def __init__( def get_loading_attributes(self): attibutes_dict = copy.deepcopy(self.__dict__) - loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"] + loading_attibutes = [ + "disable_exllama", + "use_exllama", + "exllama_config", + "use_cuda_fp16", + "max_input_length", + "backend", + ] loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} return loading_attibutes_dict @@ -692,6 +721,17 @@ def post_init(self): ['wikitext2','c4','c4-new'], but we found {self.dataset}""" ) + # make sure backend is back/forward compatible with both gptqmodel (full) and auto-gptq (partial) + if is_gptqmodel_available(): + # convert auto-gptq control into gptqmodel backend + if self.backend is None: + self.backend = "auto_trainable" if self.use_exllama is not None and not self.use_exllama else "auto" + else: + # convert gptqmodel backend `auto_trainable` into auto-gptq control + if self.backend == "auto_trainable": + self.use_exllama = False + + # auto-gptq specific kernel control logic if self.disable_exllama is None and self.use_exllama is None: # New default behaviour self.use_exllama = True @@ -725,12 +765,13 @@ def post_init(self): "speed using exllamav2 kernel by setting `exllama_config`." ) elif self.exllama_config["version"] == ExllamaVersion.TWO: - optimum_version = version.parse(importlib.metadata.version("optimum")) - autogptq_version = version.parse(importlib.metadata.version("auto_gptq")) - if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"): - raise ValueError( - f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}" - ) + if is_auto_gptq_available(): + optimum_version = version.parse(importlib.metadata.version("optimum")) + autogptq_version = version.parse(importlib.metadata.version("auto_gptq")) + if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"): + raise ValueError( + f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}" + ) if self.modules_in_block_to_quantize is not None: optimum_version = version.parse(importlib.metadata.version("optimum")) if optimum_version < version.parse("1.15.0"): diff --git a/tests/quantization/gptq/test_gptq.py b/tests/quantization/gptq/test_gptq.py index b1be9ac8c682..c0056b238663 100644 --- a/tests/quantization/gptq/test_gptq.py +++ b/tests/quantization/gptq/test_gptq.py @@ -18,16 +18,17 @@ import pytest -from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GPTQConfig from transformers.testing_utils import ( is_torch_available, require_accelerate, - require_auto_gptq, + require_gptq, require_optimum, require_torch_gpu, require_torch_multi_gpu, slow, ) +from transformers.utils import is_auto_gptq_available, is_gptqmodel_available, is_ipex_available if is_torch_available(): @@ -76,25 +77,29 @@ def test_optimum_config(self): @slow @require_optimum -@require_auto_gptq -@require_torch_gpu +@require_gptq class GPTQTest(unittest.TestCase): model_name = "bigscience/bloom-560m" input_text = "Hello my name is" EXPECTED_OUTPUTS = set() + # flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I") EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of") EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.") EXPECTED_OUTPUTS.add("Hello my name is Alyson, I am a student in the") EXPECTED_OUTPUTS.add("Hello my name is Alyson and I am a very sweet,") + EXPECTED_OUTPUTS.add("Hello my name is Aiden, I am a student at the University") + EXPECTED_OUTPUTS.add("Hello my name is Nate and I am a member of the N") + EXPECTED_OUTPUTS.add("Hello my name is Nellie and I am a student at the") # this seems a little small considering that we are doing 4bit quant but we have a small model and ww don't quantize the embeddings EXPECTED_RELATIVE_DIFFERENCE = 1.664253062 bits = 4 + sym = True group_size = 128 desc_act = False use_exllama = False @@ -103,7 +108,7 @@ class GPTQTest(unittest.TestCase): "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm." ] - device_map = None + device_map = "cpu" if is_gptqmodel_available() else None # called only once for all test in this class @classmethod @@ -117,13 +122,15 @@ def setUpClass(cls): cls.mem_fp16 = cls.model_fp16.get_memory_footprint() cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) + cls.config = AutoConfig.from_pretrained(cls.model_name) - quantization_config = GPTQConfig( + cls.quantization_config = GPTQConfig( bits=cls.bits, dataset=cls.dataset, tokenizer=cls.tokenizer, group_size=cls.group_size, desc_act=cls.desc_act, + sym=cls.sym, use_exllama=cls.use_exllama, ) @@ -131,7 +138,7 @@ def setUpClass(cls): cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map, - quantization_config=quantization_config, + quantization_config=cls.quantization_config, ) def test_memory_footprint(self): @@ -142,7 +149,7 @@ def test_memory_footprint(self): mem_quantized = self.quantized_model.get_memory_footprint() - self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE) + self.assertAlmostEqual(self.mem_fp16 / mem_quantized, self.EXPECTED_RELATIVE_DIFFERENCE, places=4) def test_device_and_dtype_assignment(self): r""" @@ -150,7 +157,7 @@ def test_device_and_dtype_assignment(self): Checks also if other models are casted correctly. """ # This should work - if self.device_map is None: + if self.device_map in (None, "cpu"): _ = self.quantized_model.to(0) with self.assertRaises(ValueError): @@ -170,16 +177,36 @@ def test_quantized_layers_class(self): Simple test to check if the model conversion has been done correctly by checking on the class type of the linear layers of the converted models """ - from auto_gptq.utils.import_utils import dynamically_import_QuantLinear - - QuantLinear = dynamically_import_QuantLinear( - use_triton=False, - desc_act=self.desc_act, - group_size=self.group_size, - bits=self.bits, - disable_exllama=not self.use_exllama, - disable_exllamav2=True, - ) + if is_gptqmodel_available(): + from gptqmodel.utils.importer import hf_select_quant_linear + + if hasattr(self.config, "quantization_config"): + checkpoint_format = self.config.quantization_config.get("checkpoint_format") + meta = self.config.quantization_config.get("meta") + else: + checkpoint_format = "gptq" + meta = None + QuantLinear = hf_select_quant_linear( + bits=self.bits, + group_size=self.group_size, + desc_act=self.desc_act, + sym=self.sym, + device_map=self.device_map, + checkpoint_format=checkpoint_format, + meta=meta, + backend=self.quantization_config.backend, + ) + elif is_auto_gptq_available(): + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear + + QuantLinear = hf_select_quant_linear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=not self.use_exllama, + disable_exllamav2=True, + ) self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear) def check_inference_correctness(self, model): @@ -192,7 +219,7 @@ def check_inference_correctness(self, model): encoded_input = self.tokenizer(self.input_text, return_tensors="pt") # Check the exactness of the results - output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(model.device), max_new_tokens=10) # Get the generation self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @@ -207,6 +234,8 @@ def test_generate_quality(self): if self.device_map is None: self.check_inference_correctness(self.quantized_model.to(0)) else: + if self.device_map == "cpu" and self.quantized_model.device.type != "cpu": + self.quantized_model.to("cpu") self.check_inference_correctness(self.quantized_model) def test_serialization(self): @@ -215,15 +244,28 @@ def test_serialization(self): """ with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname) - if not self.use_exllama: - quantized_model_from_saved = AutoModelForCausalLM.from_pretrained( - tmpdirname, quantization_config=GPTQConfig(use_exllama=False, bits=4) - ).to(0) - self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old") + if is_auto_gptq_available() and not is_gptqmodel_available(): + quant_type = "cuda-old" if not self.use_exllama else "exllama" + if not self.use_exllama: + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained( + tmpdirname, quantization_config=GPTQConfig(use_exllama=False, bits=4) + ) + if self.device_map != "cpu": + quantized_model_from_saved = quantized_model_from_saved.to(0) + else: + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained( + tmpdirname, device_map=self.device_map + ) else: - # we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel - quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": 0}) - self.check_quantized_layers_type(quantized_model_from_saved, "exllama") + if self.device_map == "cpu": + quant_type = "ipex" if is_ipex_available() else "torch" + else: + quant_type = "exllama" + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained( + tmpdirname, device_map=self.device_map + ) + + self.check_quantized_layers_type(quantized_model_from_saved, quant_type) self.check_inference_correctness(quantized_model_from_saved) @require_accelerate @@ -233,20 +275,26 @@ def test_serialization_big_model_inference(self): """ with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname) - quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto") + device_map = self.device_map or "auto" + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=device_map) self.check_inference_correctness(quantized_model_from_saved) + +@require_torch_gpu +class GPTQTestCUDA(GPTQTest): + device_map = {"": 0} + def test_change_loading_attributes(self): """ Test the serialization of the model and the loading of the quantized weights works with another config file """ with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname) - if not self.use_exllama: + if is_auto_gptq_available() and not is_gptqmodel_available() and not self.use_exllama: self.check_quantized_layers_type(self.quantized_model, "cuda-old") # we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel quantized_model_from_saved = AutoModelForCausalLM.from_pretrained( - tmpdirname, quantization_config=GPTQConfig(use_exllama=True, bits=4), device_map={"": 0} + tmpdirname, quantization_config=GPTQConfig(use_exllama=True, bits=4), device_map=self.device_map ) self.assertEqual(quantized_model_from_saved.config.quantization_config.bits, self.bits) self.check_quantized_layers_type(quantized_model_from_saved, "exllama") @@ -255,20 +303,20 @@ def test_change_loading_attributes(self): @require_accelerate @require_torch_multi_gpu -class GPTQTestDeviceMap(GPTQTest): +class GPTQTestDeviceMap(GPTQTestCUDA): device_map = "auto" @require_accelerate @require_torch_multi_gpu -class GPTQTestDeviceMapExllama(GPTQTest): +class GPTQTestDeviceMapExllama(GPTQTestCUDA): device_map = "auto" use_exllama = True @slow @require_optimum -@require_auto_gptq +@require_gptq @require_torch_gpu @require_accelerate class GPTQTestActOrderExllama(unittest.TestCase): @@ -279,6 +327,7 @@ class GPTQTestActOrderExllama(unittest.TestCase): """ EXPECTED_OUTPUTS = set() + # flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions EXPECTED_OUTPUTS.add("Hello, how are you ? I'm doing good, thanks for asking.") # 4bit + act_order + 128g model_name = "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ" @@ -343,7 +392,7 @@ def test_max_input_length(self): @slow @require_optimum -@require_auto_gptq +@require_gptq @require_torch_gpu @require_accelerate class GPTQTestExllamaV2(unittest.TestCase): @@ -354,6 +403,7 @@ class GPTQTestExllamaV2(unittest.TestCase): """ EXPECTED_OUTPUTS = set() + # flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions EXPECTED_OUTPUTS.add("Hello, how are you ? I'm doing good, thanks for asking.") # 4bit + act_order + 128g model_name = "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ" @@ -374,7 +424,10 @@ def setUpClass(cls): cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) def test_quantized_layers_type(self): - self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllamav2") + self.assertEqual( + self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE, + "exllama" if is_gptqmodel_available() else "exllamav2", + ) def check_inference_correctness(self, model): """ diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 4a6dae67cbc8..053d2cf6397a 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -21,7 +21,7 @@ from transformers import set_seed from transformers.testing_utils import ( is_torch_available, - require_auto_gptq, + require_gptq, require_non_xpu, require_read_token, require_torch, @@ -319,7 +319,7 @@ def test_hybrid_cache_n_sequences(self): self.assertListEqual(decoded, expected_text) @require_non_xpu - @require_auto_gptq + @require_gptq def test_sink_cache_hard(self): tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ") model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")