
FEAT : Adding VPTQ quantization method to HFQuantizer (#34770)
* init vptq

* add integration

* add vptq support

fix readme

* add tests && format

* format

* address comments

* format

* format

* address comments

* format

* address comments

* remove debug code

* Revert "remove debug code"

This reverts commit ed3b3ea.

* fix test

---------

Co-authored-by: Yang Wang <wyatuestc@gmail.com>
wejoncy and YangWang92 authored Dec 20, 2024
1 parent 5a2aedc commit 4e27a40
Showing 21 changed files with 647 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docker/transformers-quantization-latest-gpu/Dockerfile
@@ -50,6 +50,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/pef
# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2

# Add vptq for quantization testing
RUN python3 -m pip install --no-cache-dir vptq

# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq

2 changes: 2 additions & 0 deletions docs/source/ar/_toctree.yml
@@ -157,6 +157,8 @@
# title: AWQ
# - local: quantization/aqlm
# title: AQLM
# - local: quantization/vptq
# title: VPTQ
# - local: quantization/quanto
# title: Quanto
# - local: quantization/eetq
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -167,6 +167,8 @@
title: AWQ
- local: quantization/aqlm
title: AQLM
- local: quantization/vptq
title: VPTQ
- local: quantization/quanto
title: Quanto
- local: quantization/eetq
2 changes: 1 addition & 1 deletion docs/source/en/llm_optims.md
@@ -473,7 +473,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable
Quantization reduces the size of the LLM weights by storing them in a lower precision. This translates to lower memory usage and makes loading LLMs for inference more accessible if you're constrained by your GPUs memory. If you aren't limited by your GPU, you don't necessarily need to quantize your model because it can incur a small latency cost (except for AWQ and fused AWQ modules) due to the extra step required to quantize and dequantize the weights.

> [!TIP]
> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.
> There are many quantization libraries (see the [Quantization](./quantization) guide for more details) available, such as Quanto, AQLM, VPTQ, AWQ, and AutoGPTQ. Feel free to try them out and see which one works best for your use case. We also recommend reading the [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) blog post which compares AutoGPTQ and bitsandbytes.
Use the Model Memory Calculator below to estimate and compare how much memory is required to load a model. For example, try estimating how much memory it costs to load [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).

4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
@@ -34,6 +34,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] AqlmConfig

## VptqConfig

[[autodoc]] VptqConfig

## AwqConfig

[[autodoc]] AwqConfig
3 changes: 2 additions & 1 deletion docs/source/en/quantization/overview.md
@@ -58,6 +58,7 @@ Use the table below to help you decide which quantization method to use.
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | partial support (int4 weight only) | 🔴 | | 4 / 8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |

<Tip>

@@ -71,4 +72,4 @@ We value your feedback to help identify bugs before the full release! Check out

\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships.

</Tip>
</Tip>
111 changes: 111 additions & 0 deletions docs/source/en/quantization/vptq.md

Large diffs are not rendered by default.
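
The new VPTQ documentation page itself is not rendered in this view. As a rough sketch of the usage it covers — loading a checkpoint that was already quantized with VPTQ — the snippet below is illustrative only; the repo id is hypothetical and the actual doc may use different checkpoints and options.

```python
# Minimal sketch (repo id is hypothetical): a prequantized VPTQ checkpoint
# carries its own quantization_config, so it loads like any other model.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "VPTQ-community/Llama-3.1-8B-Instruct-vptq-2bit"  # hypothetical
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Explain vector post-training quantization briefly.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```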

4 changes: 4 additions & 0 deletions docs/source/ko/_toctree.yml
@@ -151,6 +151,8 @@
title: AWQ
- local: in_translation
title: (번역중) AQLM
- local: in_translation
title: (번역중) VPTQ
- local: in_translation
title: (번역중) Quanto
- local: in_translation
@@ -173,6 +175,8 @@
title: (번역중) AWQ
- local: in_translation
title: (번역중) AQLM
- local: in_translation
title: (번역중) VPTQ
- local: quantization/quanto
title: Quanto
- local: quantization/eetq
2 changes: 1 addition & 1 deletion docs/source/ko/llm_optims.md
@@ -375,7 +375,7 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable
양자화는 LLM 가중치를 더 낮은 정밀도로 저장하여 크기를 줄입니다. 이는 메모리 사용량을 줄이며 GPU 메모리에 제약이 있는 경우 추론을 위해 LLM을 로드하는 것을 더 용이하게 합니다. GPU가 충분하다면, 모델을 양자화할 필요는 없습니다. 추가적인 양자화 및 양자화 해제 단계로 인해 약간의 지연이 발생할 수 있기 때문입니다(AWQ 및 융합 AWQ 모듈 제외).

> [!TIP]
> 다양한 양자화 라이브러리(자세한 내용은 [Quantization](./quantization) 가이드를 참조하십시오)가 있습니다. 여기에는 Quanto, AQLM, AWQ 및 AutoGPTQ가 포함됩니다. 사용 사례에 가장 잘 맞는 라이브러리를 사용해 보십시오. 또한 AutoGPTQ와 bitsandbytes를 비교하는 [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) 블로그 게시물을 읽어보는 것을 추천합니다.
> 다양한 양자화 라이브러리(자세한 내용은 [Quantization](./quantization) 가이드를 참조하십시오)가 있습니다. 여기에는 Quanto, AQLM, VPTQ, AWQ 및 AutoGPTQ가 포함됩니다. 사용 사례에 가장 잘 맞는 라이브러리를 사용해 보십시오. 또한 AutoGPTQ와 bitsandbytes를 비교하는 [Overview of natively supported quantization schemes in 🤗 Transformers](https://hf.co/blog/overview-quantization-transformers) 블로그 게시물을 읽어보는 것을 추천합니다.
아래의 모델 메모리 계산기를 사용하여 모델을 로드하는 데 필요한 메모리를 추정하고 비교해 보십시오. 예를 들어 [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)를 로드하는 데 필요한 메모리를 추정해 보십시오.

4 changes: 4 additions & 0 deletions docs/source/ko/main_classes/quantization.md
@@ -35,6 +35,10 @@ Transformers에서 지원되지 않는 양자화 기법들은 [`HfQuantizer`]

[[autodoc]] AqlmConfig

## VptqConfig[[transformers.VptqConfig]]

[[autodoc]] VptqConfig

## AwqConfig[[transformers.AwqConfig]]

[[autodoc]] AwqConfig
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1000,6 +1000,7 @@
"HqqConfig",
"QuantoConfig",
"TorchAoConfig",
"VptqConfig",
],
}

@@ -6017,6 +6018,7 @@
        HqqConfig,
        QuantoConfig,
        TorchAoConfig,
        VptqConfig,
    )

    try:
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -105,6 +105,7 @@
    ],
    "peft": ["PeftAdapterMixin"],
    "quanto": ["replace_with_quanto_layers"],
    "vptq": ["replace_with_vptq_linear"],
}

try:
@@ -207,6 +208,7 @@
    )
    from .peft import PeftAdapterMixin
    from .quanto import replace_with_quanto_layers
    from .vptq import replace_with_vptq_linear

    try:
        if not is_torch_available():
101 changes: 101 additions & 0 deletions src/transformers/integrations/vptq.py
@@ -0,0 +1,101 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"VPTQ (Vector Post-Training Quantization) integration file"

import torch.nn as nn
from accelerate import init_empty_weights
from vptq import VQuantLinear


def replace_with_vptq_linear(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with VPTQ quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.
    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`VptqConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert in `VQuantLinear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """

    modules_to_not_convert = ["lm_head"] if not modules_to_not_convert else modules_to_not_convert

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)
        layer_name = ".".join(current_key_name)
        shared_layer_config = quantization_config.shared_layer_config
        config_for_layers = quantization_config.config_for_layers

        if (
            isinstance(module, nn.Linear)
            and layer_name not in modules_to_not_convert
            and ((layer_name in config_for_layers) or (current_key_name[-1] in shared_layer_config))
        ):
            layer_params = config_for_layers.get(layer_name, None) or shared_layer_config.get(
                current_key_name[-1], None
            )

            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features

                model._modules[name] = VQuantLinear(
                    in_features,
                    out_features,
                    vector_lens=layer_params["vector_lens"],
                    num_centroids=layer_params["num_centroids"],
                    num_res_centroids=layer_params["num_res_centroids"],
                    group_num=layer_params["group_num"],
                    group_size=layer_params["group_size"],
                    outlier_size=layer_params["outlier_size"],
                    indices_as_float=layer_params["indices_as_float"],
                    enable_norm=layer_params["enable_norm"],
                    enable_perm=layer_params["enable_perm"],
                    is_indice_packed=True,
                    enable_proxy_error=False,
                    bias=module.bias is not None,
                )
                has_been_replaced = True

                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_vptq_linear(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
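
A hedged sketch of how this helper is meant to be invoked (mirroring the call in the quantizer added below); `model` and `config` are assumed to exist already, with `config` being a `VptqConfig` whose `config_for_layers` / `shared_layer_config` come from the VPTQ tooling:

```python
# Usage sketch; `model` is an nn.Module skeleton (e.g. built with empty weights)
# and `config` is assumed to be a VptqConfig describing the quantized layers.
from transformers.integrations import replace_with_vptq_linear

model, has_been_replaced = replace_with_vptq_linear(
    model,
    quantization_config=config,
    modules_to_not_convert=["lm_head"],  # keep the LM head in full precision
)
if not has_been_replaced:
    print("No linear layers were converted; check the quantization config.")
```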
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -29,6 +29,7 @@
    QuantizationMethod,
    QuantoConfig,
    TorchAoConfig,
    VptqConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
@@ -42,6 +43,7 @@
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@@ -57,6 +59,7 @@
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
"bitnet": BitNetHfQuantizer,
"vptq": VptqHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -72,6 +75,7 @@
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
"vptq": VptqConfig,
}


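
With these registry entries, a checkpoint whose `quantization_config` declares `"quant_method": "vptq"` resolves to the new classes. A small sketch of that lookup, using only the mappings defined above:

```python
# Sketch: the "vptq" key now resolves to the new config and quantizer classes.
from transformers import VptqConfig
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

assert AUTO_QUANTIZATION_CONFIG_MAPPING["vptq"] is VptqConfig
print(AUTO_QUANTIZER_MAPPING["vptq"].__name__)  # VptqHfQuantizer
```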
98 changes: 98 additions & 0 deletions src/transformers/quantizers/quantizer_vptq.py
@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_torch_available, is_vptq_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class VptqHfQuantizer(HfQuantizer):
    """
    Quantizer of the VPTQ method. Enables the loading of prequantized models.
    """

    requires_calibration = True
    required_packages = ["vptq"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_accelerate_available():
            raise ImportError("Using `vptq` quantization requires Accelerate: `pip install accelerate`")

        if not is_vptq_available():
            raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            if torch.cuda.is_available():
                torch_dtype = torch.float16
                logger.info(
                    "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
                )
            else:
                import vptq

                device_availability = getattr(vptq, "device_availability", lambda device: False)
                if device_availability("cpu") is True:
                    raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
                torch_dtype = torch.float32
                logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        """
        We don't have a param like `modules_to_not_convert` to indicate which layers should not be quantized
        because `quantization_config` includes the layers that should be quantized.
        """
        from ..integrations import replace_with_vptq_linear

        modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
            self.quantization_config.modules_to_not_convert or []
        )

        replace_with_vptq_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=modules_to_not_convert,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
@@ -142,6 +142,7 @@
    is_torchdynamo_available,
    is_torchvision_available,
    is_vision_available,
    is_vptq_available,
    strtobool,
)

@@ -1142,6 +1143,13 @@ def require_aqlm(test_case):
    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


def require_vptq(test_case):
    """
    Decorator marking a test that requires vptq
    """
    return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)


def require_eetq(test_case):
    """
    Decorator marking a test that requires eetq
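
A hypothetical test sketch using the new `require_vptq` decorator; the checkpoint id and assertion are illustrative and not taken from the tests added in this PR:

```python
# Hypothetical test sketch; the repo id below is illustrative only.
import unittest

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_gpu, require_vptq


@require_vptq
@require_torch_gpu
class VptqSmokeTest(unittest.TestCase):
    def test_load_prequantized_model(self):
        model = AutoModelForCausalLM.from_pretrained(
            "some-org/llama-vptq-2bit",  # hypothetical prequantized checkpoint
            device_map="auto",
        )
        # The VPTQ quantizer records its config on the loaded model.
        self.assertIsNotNone(model.config.quantization_config)
```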
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -233,6 +233,7 @@
    is_training_run_on_sagemaker,
    is_uroman_available,
    is_vision_available,
    is_vptq_available,
    requires_backends,
    torch_only_method,
)