From bbc68feea9aa40da5302858b21bfbb767c5b69e4 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 13:23:33 +0000 Subject: [PATCH 01/73] update HQQ transformers integration --- .../Dockerfile | 3 + docs/source/en/quantization.md | 40 ++++ docs/source/en/quicktour.md | 0 src/transformers/__init__.py | 1 + src/transformers/integrations/__init__.py | 1 + src/transformers/integrations/hqq.py | 103 +++++++++++ .../integrations/integration_utils.py | 0 src/transformers/modeling_utils.py | 13 +- src/transformers/quantizers/__init__.py | 0 src/transformers/quantizers/auto.py | 4 + src/transformers/quantizers/quantizer_hqq.py | 174 ++++++++++++++++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/hqq_utils.py | 98 ++++++++++ src/transformers/utils/quantization_config.py | 122 +++++++++++- tests/quantization/hqq/test_hqq.py | 172 +++++++++++++++++ 15 files changed, 730 insertions(+), 2 deletions(-) mode change 100644 => 100755 docker/transformers-quantization-latest-gpu/Dockerfile mode change 100644 => 100755 docs/source/en/quantization.md mode change 100644 => 100755 docs/source/en/quicktour.md mode change 100644 => 100755 src/transformers/__init__.py mode change 100644 => 100755 src/transformers/integrations/__init__.py create mode 100755 src/transformers/integrations/hqq.py mode change 100644 => 100755 src/transformers/integrations/integration_utils.py mode change 100644 => 100755 src/transformers/modeling_utils.py mode change 100644 => 100755 src/transformers/quantizers/__init__.py mode change 100644 => 100755 src/transformers/quantizers/auto.py create mode 100755 src/transformers/quantizers/quantizer_hqq.py mode change 100644 => 100755 src/transformers/utils/__init__.py create mode 100755 src/transformers/utils/hqq_utils.py mode change 100644 => 100755 src/transformers/utils/quantization_config.py create mode 100755 tests/quantization/hqq/test_hqq.py diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile old mode 100644 new mode 100755 index 08bc3c45b952..47fcd11fd766 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt # Add aqlm for quantization testing RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2 +# Add hqq for quantization testing +RUN python3 -m pip install --no-cache-dir hqq + # Add autoawq for quantization testing # >=v0.2.3 needed for compatibility with torch 2.2.1 RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md old mode 100644 new mode 100755 index 8a3650a84390..f79ae9540edc --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -745,3 +745,43 @@ The speed and throughput of fused and unfused modules were also tested with the
generate throughput/batch size
+ +## HQQ +Half-Quadratic Quantization (HQQ) implements on-the-fly quantization via fast robust optimization. It doesn't require calibration data and can be used to quantize any model. +Please refer to the official package for more details. + +For installation, we recommend you use the following approach to get the latest version and build its corresponding CUDA kernels: +``` +pip install hqq +``` + +To quantize a model, you need to create an ```HqqConfig``` as follows: +``` Python +from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig + +#Linear layers will use the same quantization config +quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default + +#Each type of linear layer (referred to as linear tag) will use different quantization parameters +q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False} +q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False} +quant_config = HqqConfig(dynamic_config={ + 'self_attn.q_proj':q4_config, + 'self_attn.k_proj':q4_config, + 'self_attn.v_proj':q4_config, + 'self_attn.o_proj':q4_config, + + 'mlp.gate_proj':q3_config, + 'mlp.up_proj' :q3_config, + 'mlp.down_proj':q3_config, + }) +``` + +Then you simply quantize the model as follows +``` Python +model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda", quantization_config=quant_config) +``` +### Optimized Runtime +HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training. +For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090. +For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend \ No newline at end of file diff --git a/docs/source/en/quicktour.md b/docs/source/en/quicktour.md old mode 100644 new mode 100755 diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py old mode 100644 new mode 100755 index 3ce3e057a240..fbed2360fdbe --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1133,6 +1133,7 @@ "EetqConfig", "GPTQConfig", "QuantoConfig", + "HqqConfig", ], } diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py old mode 100644 new mode 100755 index 72fdf3e1bbb9..a9042ce6d0af --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -18,6 +18,7 @@ _import_structure = { "aqlm": ["replace_with_aqlm_linear"], + "hqq": ["prepare_for_hqq_linear"], "awq": [ "fuse_awq_modules", "post_init_awq_exllama_modules", diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py new file mode 100755 index 000000000000..b2ba47ef3887 --- /dev/null +++ b/src/transformers/integrations/hqq.py @@ -0,0 +1,103 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"HQQ (Half-Quadratic Quantization) integration file" + +import torch + +from ..utils import is_hqq_available, logging +from ..utils.hqq_utils import autoname_modules, get_linear_tags, name_to_linear_tag + + +if is_hqq_available(): + from hqq.core.quantize import HQQLinear +else: + HQQLinear = None + +logger = logging.get_logger(__name__) + + +def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None): + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, torch.nn.Linear): + # Get linear tag + linear_tag = name_to_linear_tag(module.name) + + # We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param() + if linear_tag in patch_params: + if patch_params[linear_tag] is not None: + model._modules[name].quant_config = patch_params[linear_tag] + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + has_been_replaced = True + + if len(list(module.children())) > 0: + _, has_been_replaced = _prepare_for_hqq_linear( + module, + patch_params=patch_params, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + + return model, has_been_replaced + + +def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False): + """ + Prepares nn.Linear layers for HQQ quantization. + Since each layer type can have separate quantization parameters, we need to do the following: + 1- tag each module with its neme via autoname_modules() + 2- Extract linear_tags (e.g. ['self_attn.q_proj', ...]) + 3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear exepects it, this is referred to as patch_params + """ + + modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert + + # Add name to module + autoname_modules(model) + + # Get linear tags. 
This allows us to use different quant params to different layer types + linear_tags = get_linear_tags(model) + + # Convert quantization_config to layer-wise config + skip_modules = quantization_config.skip_modules + quant_config = quantization_config.to_dict() + linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert)) + + if True in [(key in linear_tags) for key in quant_config.keys()]: + # If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None) + patch_params = {key: None for key in linear_tags} + patch_params.update(quant_config) + else: + # Same quant_config for all layers + patch_params = {k: quant_config for k in linear_tags} + + model, has_been_replaced = _prepare_for_hqq_linear( + model, patch_params=patch_params, has_been_replaced=has_been_replaced + ) + + # We store quantization config as linear_tag -> hqq quant config + model.config.quantization_config = patch_params + + if not has_been_replaced: + logger.warning("No linear modules were found in your model for quantization.") + + return model diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py old mode 100644 new mode 100755 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py old mode 100644 new mode 100755 index be164e8e2c0c..5f6cd7cf14df --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -38,6 +38,7 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss, Identity from torch.utils.checkpoint import checkpoint +from tqdm import tqdm as tqdm_lib from .activations import get_activation from .configuration_utils import PretrainedConfig @@ -808,7 +809,13 @@ def _load_state_dict_into_meta_model( for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) - for param_name, param in state_dict.items(): + # Show shard-level progress. Useful to monitor quantization progress + show_progress = False + if hf_quantizer is not None: + if hasattr(hf_quantizer, "show_progress"): + show_progress = hf_quantizer.show_progress + + for param_name, param in tqdm_lib(state_dict.items(), disable=not show_progress): # First part of the test is always true as load_state_dict_keys always contains state_dict keys. 
if param_name not in loaded_state_dict_keys or param_name not in expected_keys: continue @@ -2656,6 +2663,8 @@ def get_memory_footprint(self, return_buffers=True): @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.to` is not supported for HQQ-quantized models.") # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( @@ -2667,6 +2676,8 @@ def cuda(self, *args, **kwargs): @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.to` is not supported for HQQ-quantized models.") # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( diff --git a/src/transformers/quantizers/__init__.py b/src/transformers/quantizers/__init__.py old mode 100644 new mode 100755 diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py old mode 100644 new mode 100755 index cc58cd7af69f..b42d2337f4aa --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -21,6 +21,7 @@ BitsAndBytesConfig, EetqConfig, GPTQConfig, + HqqConfig, QuantizationConfigMixin, QuantizationMethod, QuantoConfig, @@ -31,6 +32,7 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer from .quantizer_eetq import EetqHfQuantizer from .quantizer_gptq import GptqHfQuantizer +from .quantizer_hqq import HQQHfQuantizer from .quantizer_quanto import QuantoHfQuantizer @@ -42,6 +44,7 @@ "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, "eetq": EetqHfQuantizer, + "hqq": HQQHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -52,6 +55,7 @@ "gptq": GPTQConfig, "aqlm": AqlmConfig, "quanto": QuantoConfig, + "hqq": HqqConfig, } diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py new file mode 100755 index 000000000000..41b57703c58c --- /dev/null +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List + +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..integrations import prepare_for_hqq_linear +from ..utils import is_hqq_available, is_torch_available, logging +from ..utils.hqq_utils import find_parent +from .quantizers_utils import get_module_from_name + + +if is_torch_available(): + import torch + +if is_hqq_available(): + from hqq.core.quantize import HQQLinear +else: + HQQLinear = None + +logger = logging.get_logger(__name__) + + +class HQQHfQuantizer(HfQuantizer): + """ + HQQ quantizer base HF class. + nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading(). 
+ The actually quantization and offloading to the GPU is done in check_quantized_param(). + self.show_progress (bool) is used to show quantization progress in each shard. + """ + + use_keep_in_fp32_modules = False + requires_parameters_quantization = True + requires_calibration = False + required_packages = ["hqq"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + self.show_progress = quantization_config.show_progress + self.torch_dtype = None + + def validate_environment(self, *args, **kwargs): + if not (is_hqq_available()): + raise ImportError( + "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`" + ) + + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): + raise ValueError( + "Converting weights from tf/flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + + if self.torch_dtype is None: + self.torch_dtype = kwargs.get("torch_dtype", torch.float16) + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + module, tensor_name = get_module_from_name(model, param_name) + + if isinstance(module, torch.nn.Linear): + return True + else: + return False + + return True + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: List[str], + ): + """ + Each nn.Linear layer is processsed here. + We first check if the corresponding module state_dict contains already HQQ quantized parameters. + If not, we create a temp linear layer with the module state_dict params and use it for quantization + """ + + module, tensor_name = get_module_from_name(model, param_name) + + # We only quantize torch.nn.Linear, other layers will be skipped + if type(module) is not torch.nn.Linear: + return + + layer_name = param_name.replace(".weight", "").replace(".bias", "") + parent_module = find_parent(model, layer_name) + node = layer_name.split(".")[-1] + + # Step 0: set module state_dict + module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key} + + # Step 1: populate module with weight/bias from module state dict + for key in module_state_dict: + setattr(module, key, torch.nn.Parameter(module_state_dict[key])) + + """ + Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module + directly doesn't work. + """ + + if hasattr(module, "quant_config"): + hqq_layer = HQQLinear( + module, + module.quant_config, + compute_dtype=self.torch_dtype, + device=target_device, + del_orig=True, + ) + + setattr( + parent_module, + node, + hqq_layer, + ) + + else: + module = module.to(dtype=self.torch_dtype, device=target_device) + setattr(parent_module, node, module) + + torch.cuda.empty_cache() + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + # Add the corresponding quant_config to each valid module. 
This allows us to do the actual nn.Linear > HQQLinear in create_quantized_param() + model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config) + + # model.config.quantization_config is done inside prepare_for_hqq_linear + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + model.is_hqq_quantized = True + model.is_hqq_serializable = self.is_serializable + return model + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self) -> bool: + return False diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py old mode 100644 new mode 100755 index e4ff991ed75c..b4b0181e0dbc --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -112,6 +112,7 @@ is_auto_gptq_available, is_av_available, is_bitsandbytes_available, + is_hqq_available, is_bs4_available, is_coloredlogs_available, is_cv2_available, diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py new file mode 100755 index 000000000000..658ec19a5936 --- /dev/null +++ b/src/transformers/utils/hqq_utils.py @@ -0,0 +1,98 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import is_hqq_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_hqq_available(): + from hqq.core.quantize import HQQLinear +else: + HQQLinear = None + + +# Name all modules inside the model +def autoname_modules(model): + for name, module in model.named_modules(): + module.name = name + + +# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj +def name_to_linear_tag(name): + return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))]) + + +# Get all linear tags available +def get_linear_tags(model): + linear_tags = set() + for name, module in model.named_modules(): + if type(module) in [torch.nn.Linear, HQQLinear]: + linear_tags.add(name_to_linear_tag(name)) + return list(linear_tags) + + +# Finds the parent of a node module named "name" +def find_parent(model, name): + module_tree = name.split(".")[:-1] + parent = model + for m in module_tree: + parent = parent._modules[m] + return parent + + +# checks if a module is a leaf: doesn't have another module inside +def is_leaf_module(module): + return len(module._modules) == 0 + + +# Returns layers to ignores. 
These layers are typically not leaves we are interested in for storage and loading +def get_ignore_layers(model): + layers = {""} + for name, module in model.named_modules(): + if not is_leaf_module(module): + layers.add(name) + return list(layers) + + +# Checks if a quant config is an HQQ quant config +def check_if_hqq_quant_config(quant_config): + if quant_config is None: + return False + q_keys = list(quant_config.keys()) + q_vals = [quant_config[k] for k in quant_config][0] + if isinstance(q_vals, dict): + q_keys = q_keys + list([quant_config[k] for k in quant_config][0].keys()) + return "weight_quant_params" in q_keys + + +# Returns a new module from a dummy (meta) module and a dictionary of module name -> state_dict +@torch.no_grad() +def load_hqq_module(module, weights, compute_dtype, device): + if module.name not in weights: + try: + return module.to(compute_dtype).cuda(device) + except Exception: + return module + + state_dict = weights[module.name] + if ("W_q" in state_dict) and ("meta" in state_dict): + module = HQQLinear(linear_layer=None, quant_config=None, compute_dtype=compute_dtype, device=device) + module.load_state_dict(state_dict) + else: + for key in state_dict: + setattr(module, key, torch.nn.Parameter(state_dict[key].to(compute_dtype).to(device), requires_grad=False)) + + return module diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py old mode 100644 new mode 100755 index 8374ddef81d5..46654181c777 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -24,12 +24,20 @@ from packaging import version -from ..utils import is_auto_awq_available, is_torch_available, logging +from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging if is_torch_available(): import torch +if is_hqq_available(): + from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + + hqq_default_config = HQQBaseQuantizeConfig(nbits=4, group_size=64, offload_meta=False) +else: + HQQBaseQuantizeConfig = None + hqq_default_config = None + logger = logging.get_logger(__name__) @@ -41,6 +49,7 @@ class QuantizationMethod(str, Enum): AQLM = "aqlm" QUANTO = "quanto" EETQ = "eetq" + HQQ = "hqq" class AWQLinearVersion(str, Enum): @@ -180,6 +189,117 @@ def update(self, **kwargs): return unused_kwargs +@dataclass +class HqqConfig(QuantizationConfigMixin): + """ + Main HqqConfig. + Args: + nbits (`int`): + Number of bits. + group_size (`int`): + Group-size value. + quant_zero (`bool`): + Quantize the zero-point. + quant_scale (`bool`): + Quantize the scaling. + offload_meta (`bool`): + Offload the meta-data on the CPU. + view_as_float (`bool`): + View the quantized weight as float (used in distributed training) + int (`axis`): + Axis along-which grouping is performed. + dynamic_config ('dict'): + Parameters for dynamic configuration. The key is the name tag of the layer. + skip_modules (List[str]):: + nn,Linear layers to skip. + show_progress (bool): + Show tqdm quantization for each shard + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. 
+ """ + + def __init__( + self, + nbits: int = 4, + group_size: int = 64, + quant_zero: bool = True, + quant_scale: bool = False, + offload_meta: bool = False, + view_as_float: bool = False, + axis: int = 0, + dynamic_config: dict | None = None, + skip_modules=["lm_head"], + show_progress=True, + **kwargs, + ): + if dynamic_config is not None: + self.quant_config = {} + for key in dynamic_config: + self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key]) + else: + self.quant_config = HQQBaseQuantizeConfig( + **{ + "nbits": nbits, + "group_size": group_size, + "quant_zero": quant_zero, + "quant_scale": quant_scale, + "offload_meta": offload_meta, + "view_as_float": view_as_float, + "axis": axis, + } + ) + + self.quant_method = QuantizationMethod.HQQ + self.skip_modules = skip_modules + self.show_progress = show_progress + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + pass + + def is_quantizable(self): + r""" + Returns `True` if the model is quantizable, `False` otherwise. + """ + return True + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return self.quant_config + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = HqqConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + @dataclass class BitsAndBytesConfig(QuantizationConfigMixin): """ diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py new file mode 100755 index 000000000000..393e80ed1bd2 --- /dev/null +++ b/tests/quantization/hqq/test_hqq.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import unittest + +from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig +from transformers.testing_utils import ( + require_accelerate, + require_torch_gpu, + slow, +) +from transformers.utils import is_accelerate_available, is_hqq_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + pass + +if is_hqq_available(): + from hqq.core.quantize import HQQBackend, HQQLinear + + +@require_torch_gpu +class HqqConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = HqqConfig() + hqq_orig_config = quantization_config.to_dict() + + for key in hqq_orig_config: + self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key]) + + +class HQQLLMRunner: + def __init__(self, model_id, quant_config=None, compute_dtype=torch.float16, device="cuda", cache_dir=None): + self.model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=compute_dtype, + device_map=device, + quantization_config=quant_config, + cache_dir=cache_dir, + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir) + self.device = self.model.device + HQQLinear.set_backend(HQQBackend.PYTORCH) + + +def cleanup(): + torch.cuda.empty_cache() + gc.collect() + + +@slow +@require_torch_gpu +@require_accelerate +class HQQTest(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_small_mistral_fp16_quantized_model(self): + """ + Simple LLM model testing fp16 + """ + compute_dtype = torch.float16 + device = "cuda:0" + cache_dir = None + + quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + hqq_runner = HQQLLMRunner( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + quant_config=quant_config, + compute_dtype=compute_dtype, + cache_dir=cache_dir, + device=device, + ) + + batch_size, context_size = 1, 1024 + + # Test HQQ layer + hqq_layer = hqq_runner.model.model.layers[10].self_attn.v_proj + W_r = hqq_layer.dequantize() + x = torch.randn((batch_size, context_size, 4096), device=device, dtype=compute_dtype) / 10.0 + with torch.no_grad(): + y = hqq_layer(x) + self.assertEqual(y.shape[-1], W_r.shape[0]) + self.assertEqual(y.dtype, compute_dtype) + + del W_r, x, y + cleanup() + + # Test forward pass + with torch.no_grad(): + out = hqq_runner.model( + torch.zeros([batch_size, context_size], device=hqq_runner.model.device, dtype=torch.int32) + ).logits + self.assertEqual(out.shape[0], batch_size) + self.assertEqual(out.shape[1], context_size) + + def test_mistral_bfp16_offloading_quantized_model(self): + """ + Simple LLM model testing bfp16 with offfloading + """ + compute_dtype = torch.bfloat16 + device = "cuda:0" + cache_dir = None + + q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False} + q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False} + quant_config = HqqConfig( + dynamic_config={ + "self_attn.q_proj": q4_config, + "self_attn.k_proj": q4_config, + "self_attn.v_proj": q4_config, + "self_attn.o_proj": q4_config, + "mlp.gate_proj": q3_config, + "mlp.up_proj": q3_config, + "mlp.down_proj": q3_config, + } + ) + + hqq_runner = HQQLLMRunner( + model_id="mistralai/Mistral-7B-Instruct-v0.2", + quant_config=quant_config, + compute_dtype=compute_dtype, + cache_dir=cache_dir, + device=device, + ) + + batch_size, context_size = 1, 1024 + + # Test HQQ layer + hqq_layer = 
hqq_runner.model.model.layers[10].self_attn.v_proj + W_r = hqq_layer.dequantize() + x = torch.randn((batch_size, context_size, 4096), device=device, dtype=compute_dtype) / 10.0 + with torch.no_grad(): + y = hqq_layer(x) + self.assertEqual(y.shape[-1], W_r.shape[0]) + self.assertEqual(y.dtype, compute_dtype) + + # Check device + self.assertEqual(hqq_layer.W_q.device.type, "cuda") + self.assertEqual(hqq_layer.meta["zero_scale"].device.type, "cpu") + + del W_r, x, y + cleanup() + + # Test forward pass + with torch.no_grad(): + out = hqq_runner.model( + torch.zeros([batch_size, context_size], device=hqq_runner.model.device, dtype=torch.int32) + ).logits + self.assertEqual(out.shape[0], batch_size) + self.assertEqual(out.shape[1], context_size) From e1e5df68559ef0bd7d7dfda473728ebdfe2b5191 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 14:05:54 +0000 Subject: [PATCH 02/73] push import_utils.py --- src/transformers/utils/import_utils.py | 5 +++++ 1 file changed, 5 insertions(+) mode change 100644 => 100755 src/transformers/utils/import_utils.py diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py old mode 100644 new mode 100755 index c65d4122b787..158896347a7a --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -170,6 +170,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _torchdistx_available = _is_package_available("torchdistx") _torchvision_available = _is_package_available("torchvision") _mlx_available = _is_package_available("mlx") +_hqq_available = _is_package_available("hqq") _torch_version = "N/A" @@ -292,6 +293,10 @@ def is_torch_available(): return _torch_available +def is_hqq_available(): + return _hqq_available + + def get_torch_version(): return _torch_version From 0192b03b81e03aeee56330dc8601435f90e107b4 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 14:59:22 +0000 Subject: [PATCH 03/73] add force_hooks check in modeling_utils.py --- src/transformers/modeling_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5f6cd7cf14df..2ed5e9fae678 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3744,6 +3744,9 @@ def from_pretrained( } if "skip_keys" in inspect.signature(dispatch_model).parameters: device_map_kwargs["skip_keys"] = model._skip_keys_device_placement + # For HQQ method we force-set the hooks for single GPU envs + if "force_hooks" in inspect.signature(dispatch_model).parameters and hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ: + device_map_kwargs["force_hooks"] = True if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) From 823de37282347d4f97f079027027a0240b3fa912 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 15:00:46 +0000 Subject: [PATCH 04/73] fix | with Optional --- src/transformers/utils/quantization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 46654181c777..1476a68c8b5f 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -227,7 +227,7 @@ def __init__( offload_meta: bool = False, view_as_float: bool = False, axis: int = 0, - dynamic_config: dict | None = None, + dynamic_config: 
Optional[dict] = None, skip_modules=["lm_head"], show_progress=True, **kwargs, From 08d7b8e644bebbff85ba29612b4f4713121a7b85 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 15:05:53 +0000 Subject: [PATCH 05/73] force bias as param --- src/transformers/quantizers/quantizer_hqq.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 41b57703c58c..97f0f084a717 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -136,6 +136,9 @@ def create_quantized_param( del_orig=True, ) + if(hqq_layer.bias is not None): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + setattr( parent_module, node, From e1fa6c96252e78b5b3709327c79ce49b527f67db Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 15:08:30 +0000 Subject: [PATCH 06/73] check bias is Tensor --- src/transformers/quantizers/quantizer_hqq.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 97f0f084a717..f22a30181ecf 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -136,8 +136,9 @@ def create_quantized_param( del_orig=True, ) - if(hqq_layer.bias is not None): - hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + if hqq_layer.bias is not None: + if isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) setattr( parent_module, From 6e854cae6a45705fa0db2a4db3b3b23c75f57444 Mon Sep 17 00:00:00 2001 From: mobicham Date: Wed, 24 Apr 2024 15:58:25 +0000 Subject: [PATCH 07/73] force forward for multi-gpu --- src/transformers/quantizers/quantizer_hqq.py | 27 +++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index f22a30181ecf..a9302f69f3a3 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -21,11 +21,14 @@ from ..modeling_utils import PreTrainedModel from ..integrations import prepare_for_hqq_linear -from ..utils import is_hqq_available, is_torch_available, logging +from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging from ..utils.hqq_utils import find_parent from .quantizers_utils import get_module_from_name +if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + if is_torch_available(): import torch @@ -54,6 +57,7 @@ def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) self.show_progress = quantization_config.show_progress self.torch_dtype = None + self.using_multi_gpu = False def validate_environment(self, *args, **kwargs): if not (is_hqq_available()): @@ -73,6 +77,11 @@ def validate_environment(self, *args, **kwargs): if self.torch_dtype is None: self.torch_dtype = kwargs.get("torch_dtype", torch.float16) + if self.using_multi_gpu is False: + if "device_map" in kwargs: + if isinstance(kwargs["device_map"], dict): + self.using_multi_gpu = len({item[1] for item in kwargs["device_map"].items()}) > 1 + def check_quantized_param( self, model: "PreTrainedModel", @@ -140,6 +149,9 @@ def create_quantized_param( if isinstance(hqq_layer.bias, torch.Tensor): hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + if self.using_multi_gpu: + hqq_layer = 
self._patch_layer_for_multigpu(hqq_layer) + setattr( parent_module, node, @@ -152,6 +164,19 @@ def create_quantized_param( torch.cuda.empty_cache() + # Remove accelerate hook and uses a simpler forward pass. Otherwise, this breaks with multi-gpu + def _patch_layer_for_multigpu(self, hqq_layer): + hqq_layer = remove_hook_from_module(hqq_layer) + + def forward_with_device(self, x): + out = torch.matmul(x.to(self.device), self.dequantize().t()) + if self.bias is not None: + out += self.bias + return out + + hqq_layer.forward = lambda x: forward_with_device(hqq_layer, x) + return hqq_layer + def _process_model_before_weight_loading( self, model: "PreTrainedModel", From 2b9f271adeb8db20fa397d71b9b3ce773dd87fef Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 12:01:44 +0000 Subject: [PATCH 08/73] review fixes pass --- src/transformers/integrations/hqq.py | 9 ++++++--- src/transformers/modeling_utils.py | 12 ++++++++---- src/transformers/quantizers/quantizer_hqq.py | 10 +++------- src/transformers/utils/hqq_utils.py | 11 ++++++----- src/transformers/utils/quantization_config.py | 9 --------- 5 files changed, 23 insertions(+), 28 deletions(-) diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index b2ba47ef3887..0c62098ca3ed 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -13,12 +13,15 @@ # limitations under the License. "HQQ (Half-Quadratic Quantization) integration file" -import torch - -from ..utils import is_hqq_available, logging +from ..utils import is_hqq_available, is_torch_available, logging from ..utils.hqq_utils import autoname_modules, get_linear_tags, name_to_linear_tag +if is_torch_available(): + import torch +else: + torch = None + if is_hqq_available(): from hqq.core.quantize import HQQLinear else: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2ed5e9fae678..c94697b20e1c 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -810,12 +810,12 @@ def _load_state_dict_into_meta_model( state_dict[new_key] = state_dict.pop(old_key) # Show shard-level progress. Useful to monitor quantization progress - show_progress = False + quant_show_progress = False if hf_quantizer is not None: if hasattr(hf_quantizer, "show_progress"): - show_progress = hf_quantizer.show_progress + quant_show_progress = hf_quantizer.show_progress - for param_name, param in tqdm_lib(state_dict.items(), disable=not show_progress): + for param_name, param in tqdm_lib(state_dict.items(), disable=not quant_show_progress): # First part of the test is always true as load_state_dict_keys always contains state_dict keys. 
if param_name not in loaded_state_dict_keys or param_name not in expected_keys: continue @@ -3745,7 +3745,11 @@ def from_pretrained( if "skip_keys" in inspect.signature(dispatch_model).parameters: device_map_kwargs["skip_keys"] = model._skip_keys_device_placement # For HQQ method we force-set the hooks for single GPU envs - if "force_hooks" in inspect.signature(dispatch_model).parameters and hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ: + if ( + "force_hooks" in inspect.signature(dispatch_model).parameters + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ + ): device_map_kwargs["force_hooks"] = True if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index a9302f69f3a3..6075d51dbdad 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -32,11 +32,6 @@ if is_torch_available(): import torch -if is_hqq_available(): - from hqq.core.quantize import HQQLinear -else: - HQQLinear = None - logger = logging.get_logger(__name__) @@ -97,8 +92,6 @@ def check_quantized_param( else: return False - return True - def create_quantized_param( self, model: "PreTrainedModel", @@ -114,6 +107,9 @@ def create_quantized_param( If not, we create a temp linear layer with the module state_dict params and use it for quantization """ + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + module, tensor_name = get_module_from_name(model, param_name) # We only quantize torch.nn.Linear, other layers will be skipped diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py index 658ec19a5936..c2da5d036cbc 100755 --- a/src/transformers/utils/hqq_utils.py +++ b/src/transformers/utils/hqq_utils.py @@ -18,11 +18,6 @@ if is_torch_available(): import torch -if is_hqq_available(): - from hqq.core.quantize import HQQLinear -else: - HQQLinear = None - # Name all modules inside the model def autoname_modules(model): @@ -37,6 +32,9 @@ def name_to_linear_tag(name): # Get all linear tags available def get_linear_tags(model): + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + linear_tags = set() for name, module in model.named_modules(): if type(module) in [torch.nn.Linear, HQQLinear]: @@ -81,6 +79,9 @@ def check_if_hqq_quant_config(quant_config): # Returns a new module from a dummy (meta) module and a dictionary of module name -> state_dict @torch.no_grad() def load_hqq_module(module, weights, compute_dtype, device): + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + if module.name not in weights: try: return module.to(compute_dtype).cuda(device) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 1476a68c8b5f..aae6dd3eecbc 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -32,11 +32,8 @@ if is_hqq_available(): from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig - - hqq_default_config = HQQBaseQuantizeConfig(nbits=4, group_size=64, offload_meta=False) else: HQQBaseQuantizeConfig = None - hqq_default_config = None logger = logging.get_logger(__name__) @@ -261,12 +258,6 @@ def post_init(self): """ pass - def is_quantizable(self): - r""" - Returns `True` if the model is quantizable, `False` 
otherwise. - """ - return True - def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. Returns: From 5bb9ca25cef6ef8361979eb183eba5df7dc8f204 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 14:17:18 +0000 Subject: [PATCH 09/73] remove torch grad() --- src/transformers/utils/hqq_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py index c2da5d036cbc..2091ad980d06 100755 --- a/src/transformers/utils/hqq_utils.py +++ b/src/transformers/utils/hqq_utils.py @@ -77,14 +77,13 @@ def check_if_hqq_quant_config(quant_config): # Returns a new module from a dummy (meta) module and a dictionary of module name -> state_dict -@torch.no_grad() def load_hqq_module(module, weights, compute_dtype, device): if is_hqq_available(): from hqq.core.quantize import HQQLinear if module.name not in weights: try: - return module.to(compute_dtype).cuda(device) + return module.to(dtype=compute_dtype, device=device) except Exception: return module @@ -94,6 +93,10 @@ def load_hqq_module(module, weights, compute_dtype, device): module.load_state_dict(state_dict) else: for key in state_dict: - setattr(module, key, torch.nn.Parameter(state_dict[key].to(compute_dtype).to(device), requires_grad=False)) + setattr( + module, + key, + torch.nn.Parameter(state_dict[key].to(device=device, dtype=compute_dtype), requires_grad=False), + ) return module From 392e7c5e3689b9053f25b2f2a355a2aa66b3dafc Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 14:19:45 +0000 Subject: [PATCH 10/73] if any key in linear_tags fix --- src/transformers/integrations/hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index 0c62098ca3ed..7223731db9bd 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -85,7 +85,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve quant_config = quantization_config.to_dict() linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert)) - if True in [(key in linear_tags) for key in quant_config.keys()]: + if any(key in linear_tags for key in quant_config.keys()): # If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None) patch_params = {key: None for key in linear_tags} patch_params.update(quant_config) From 20f9ad5bfa584064c1ab00ec60f288aa64324e76 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 14:28:36 +0000 Subject: [PATCH 11/73] add cpu/disk check --- src/transformers/quantizers/quantizer_hqq.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 6075d51dbdad..5f185650e0bd 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -75,7 +75,8 @@ def validate_environment(self, *args, **kwargs): if self.using_multi_gpu is False: if "device_map" in kwargs: if isinstance(kwargs["device_map"], dict): - self.using_multi_gpu = len({item[1] for item in kwargs["device_map"].items()}) > 1 + target_devices = {item[1] for item in kwargs["device_map"].items()} - set({"cpu", "disk"}) + self.using_multi_gpu = len(target_devices) > 1 def check_quantized_param( self, From 3a5679a9eb5d093246d00d6edd8265409900ec4e Mon Sep 17 00:00:00 2001 From: mobicham Date: 
Thu, 25 Apr 2024 14:33:41 +0000 Subject: [PATCH 12/73] isinstance return --- src/transformers/quantizers/quantizer_hqq.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 5f185650e0bd..2321a1d03c30 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -88,10 +88,7 @@ def check_quantized_param( ) -> bool: module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, torch.nn.Linear): - return True - else: - return False + return isinstance(module, torch.nn.Linear) def create_quantized_param( self, From 7a1bbca2ac992313cb64415e4d205f3d1552dc0c Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 15:26:51 +0000 Subject: [PATCH 13/73] add multigpu test + refactor tests --- tests/quantization/hqq/test_hqq.py | 122 +++++++++++++++++------------ 1 file changed, 71 insertions(+), 51 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 393e80ed1bd2..ac625dfb5f33 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -20,6 +20,7 @@ from transformers.testing_utils import ( require_accelerate, require_torch_gpu, + require_torch_multi_gpu, slow, ) from transformers.utils import is_accelerate_available, is_hqq_available, is_torch_available @@ -55,6 +56,7 @@ def __init__(self, model_id, quant_config=None, compute_dtype=torch.float16, dev torch_dtype=compute_dtype, device_map=device, quantization_config=quant_config, + low_cpu_mem_usage=True, cache_dir=cache_dir, ) self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir) @@ -67,16 +69,46 @@ def cleanup(): gc.collect() +def test_hqqlayer(self, hqq_layer, batch_size=1, context_size=1024): + # Test HQQ layer + W_r = hqq_layer.dequantize() + x = ( + torch.randn( + (batch_size, context_size, hqq_layer.meta["shape"][1]), + device=hqq_layer.device, + dtype=hqq_layer.compute_dtype, + ) + / 10.0 + ) + with torch.no_grad(): + y = hqq_layer(x) + self.assertEqual(y.shape[-1], W_r.shape[0]) + self.assertEqual(y.dtype, hqq_layer.compute_dtype) + del W_r, x, y + cleanup() + + +def test_forward(self, model, batch_size=1, context_size=1024): + # Test forward pass + with torch.no_grad(): + out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits + self.assertEqual(out.shape[0], batch_size) + self.assertEqual(out.shape[1], context_size) + cleanup() + + +model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +# model_id ="mistralai/Mistral-7B-Instruct-v0.2" + + @slow @require_torch_gpu @require_accelerate class HQQTest(unittest.TestCase): def tearDown(self): - gc.collect() - torch.cuda.empty_cache() - gc.collect() + cleanup() - def test_small_mistral_fp16_quantized_model(self): + def test_fp16_quantized_model(self): """ Simple LLM model testing fp16 """ @@ -85,46 +117,28 @@ def test_small_mistral_fp16_quantized_model(self): cache_dir = None quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + hqq_runner = HQQLLMRunner( - model_id="mistralai/Mistral-7B-Instruct-v0.2", + model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, cache_dir=cache_dir, device=device, ) - batch_size, context_size = 1, 1024 - - # Test HQQ layer - hqq_layer = hqq_runner.model.model.layers[10].self_attn.v_proj - W_r = hqq_layer.dequantize() - x = torch.randn((batch_size, context_size, 4096), 
device=device, dtype=compute_dtype) / 10.0 - with torch.no_grad(): - y = hqq_layer(x) - self.assertEqual(y.shape[-1], W_r.shape[0]) - self.assertEqual(y.dtype, compute_dtype) - - del W_r, x, y - cleanup() - - # Test forward pass - with torch.no_grad(): - out = hqq_runner.model( - torch.zeros([batch_size, context_size], device=hqq_runner.model.device, dtype=torch.int32) - ).logits - self.assertEqual(out.shape[0], batch_size) - self.assertEqual(out.shape[1], context_size) + test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + test_forward(self, hqq_runner.model) - def test_mistral_bfp16_offloading_quantized_model(self): + def test_bfp16_quantized_model_with_offloading(self): """ - Simple LLM model testing bfp16 with offfloading + Simple LLM model testing bfp16 with meta-data offloading """ compute_dtype = torch.bfloat16 device = "cuda:0" cache_dir = None q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False} - q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False} + q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True} quant_config = HqqConfig( dynamic_config={ "self_attn.q_proj": q4_config, @@ -138,35 +152,41 @@ def test_mistral_bfp16_offloading_quantized_model(self): ) hqq_runner = HQQLLMRunner( - model_id="mistralai/Mistral-7B-Instruct-v0.2", + model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, cache_dir=cache_dir, device=device, ) - batch_size, context_size = 1, 1024 - - # Test HQQ layer - hqq_layer = hqq_runner.model.model.layers[10].self_attn.v_proj - W_r = hqq_layer.dequantize() - x = torch.randn((batch_size, context_size, 4096), device=device, dtype=compute_dtype) / 10.0 - with torch.no_grad(): - y = hqq_layer(x) - self.assertEqual(y.shape[-1], W_r.shape[0]) - self.assertEqual(y.dtype, compute_dtype) + test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + test_forward(self, hqq_runner.model) - # Check device - self.assertEqual(hqq_layer.W_q.device.type, "cuda") - self.assertEqual(hqq_layer.meta["zero_scale"].device.type, "cpu") - del W_r, x, y +@slow +@require_torch_gpu +@require_torch_multi_gpu +@require_accelerate +class HQQTestMultiGPU(unittest.TestCase): + def tearDown(self): cleanup() - # Test forward pass - with torch.no_grad(): - out = hqq_runner.model( - torch.zeros([batch_size, context_size], device=hqq_runner.model.device, dtype=torch.int32) - ).logits - self.assertEqual(out.shape[0], batch_size) - self.assertEqual(out.shape[1], context_size) + def test_fp16_quantized_model_multipgpu(self): + """ + Simple LLM model testing fp16 with multi-gpu + """ + compute_dtype = torch.float16 + cache_dir = None + + quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + + hqq_runner = HQQLLMRunner( + model_id=model_id, + quant_config=quant_config, + compute_dtype=compute_dtype, + cache_dir=cache_dir, + device="auto", + ) + + test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + test_forward(self, hqq_runner.model) From 65b288799287d7fa979a3745351512efb5449f2e Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 15:47:26 +0000 Subject: [PATCH 14/73] clean hqq_utils imports in hqq.py --- src/transformers/integrations/hqq.py | 24 +++++++++++++++++++++++- src/transformers/utils/hqq_utils.py | 23 ----------------------- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/transformers/integrations/hqq.py 
b/src/transformers/integrations/hqq.py index 7223731db9bd..ab9125bb9f9b 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -14,7 +14,6 @@ "HQQ (Half-Quadratic Quantization) integration file" from ..utils import is_hqq_available, is_torch_available, logging -from ..utils.hqq_utils import autoname_modules, get_linear_tags, name_to_linear_tag if is_torch_available(): @@ -30,6 +29,29 @@ logger = logging.get_logger(__name__) +# Name all modules inside the model +def autoname_modules(model): + for name, module in model.named_modules(): + module.name = name + + +# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj +def name_to_linear_tag(name): + return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))]) + + +# Get all linear tags available +def get_linear_tags(model): + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + + linear_tags = set() + for name, module in model.named_modules(): + if type(module) in [torch.nn.Linear, HQQLinear]: + linear_tags.add(name_to_linear_tag(name)) + return list(linear_tags) + + def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None): for name, module in model.named_children(): if current_key_name is None: diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py index 2091ad980d06..652a5b2c91e0 100755 --- a/src/transformers/utils/hqq_utils.py +++ b/src/transformers/utils/hqq_utils.py @@ -19,29 +19,6 @@ import torch -# Name all modules inside the model -def autoname_modules(model): - for name, module in model.named_modules(): - module.name = name - - -# Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj -def name_to_linear_tag(name): - return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))]) - - -# Get all linear tags available -def get_linear_tags(model): - if is_hqq_available(): - from hqq.core.quantize import HQQLinear - - linear_tags = set() - for name, module in model.named_modules(): - if type(module) in [torch.nn.Linear, HQQLinear]: - linear_tags.add(name_to_linear_tag(name)) - return list(linear_tags) - - # Finds the parent of a node module named "name" def find_parent(model, name): module_tree = name.split(".")[:-1] From bba74cd2aa6ba8c165cf3f637819ac6a594f9196 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 15:57:06 +0000 Subject: [PATCH 15/73] clean hqq_utils imports in quantizer_hqq.py --- src/transformers/quantizers/quantizer_hqq.py | 10 +++++++++- src/transformers/utils/hqq_utils.py | 9 --------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 2321a1d03c30..5f9a821dc608 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -22,7 +22,6 @@ from ..integrations import prepare_for_hqq_linear from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging -from ..utils.hqq_utils import find_parent from .quantizers_utils import get_module_from_name @@ -35,6 +34,15 @@ logger = logging.get_logger(__name__) +# Finds the parent of a node module named "name" +def find_parent(model, name): + module_tree = name.split(".")[:-1] + parent = model + for m in module_tree: + parent = parent._modules[m] + return parent + + class 
HQQHfQuantizer(HfQuantizer): """ HQQ quantizer base HF class. diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py index 652a5b2c91e0..72b73e4dc092 100755 --- a/src/transformers/utils/hqq_utils.py +++ b/src/transformers/utils/hqq_utils.py @@ -19,15 +19,6 @@ import torch -# Finds the parent of a node module named "name" -def find_parent(model, name): - module_tree = name.split(".")[:-1] - parent = model - for m in module_tree: - parent = parent._modules[m] - return parent - - # checks if a module is a leaf: doesn't have another module inside def is_leaf_module(module): return len(module._modules) == 0 From de88c2afa17048f4c7faab31319d7e06bf9e0056 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 16:03:52 +0000 Subject: [PATCH 16/73] delete hqq_utils.py --- src/transformers/utils/hqq_utils.py | 80 ++++++++++++++--------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py index 72b73e4dc092..9b4c750ca3b5 100755 --- a/src/transformers/utils/hqq_utils.py +++ b/src/transformers/utils/hqq_utils.py @@ -19,52 +19,52 @@ import torch -# checks if a module is a leaf: doesn't have another module inside -def is_leaf_module(module): - return len(module._modules) == 0 +# # checks if a module is a leaf: doesn't have another module inside +# def is_leaf_module(module): +# return len(module._modules) == 0 -# Returns layers to ignores. These layers are typically not leaves we are interested in for storage and loading -def get_ignore_layers(model): - layers = {""} - for name, module in model.named_modules(): - if not is_leaf_module(module): - layers.add(name) - return list(layers) +# # Returns layers to ignores. These layers are typically not leaves we are interested in for storage and loading +# def get_ignore_layers(model): +# layers = {""} +# for name, module in model.named_modules(): +# if not is_leaf_module(module): +# layers.add(name) +# return list(layers) -# Checks if a quant config is an HQQ quant config -def check_if_hqq_quant_config(quant_config): - if quant_config is None: - return False - q_keys = list(quant_config.keys()) - q_vals = [quant_config[k] for k in quant_config][0] - if isinstance(q_vals, dict): - q_keys = q_keys + list([quant_config[k] for k in quant_config][0].keys()) - return "weight_quant_params" in q_keys +# # Checks if a quant config is an HQQ quant config +# def check_if_hqq_quant_config(quant_config): +# if quant_config is None: +# return False +# q_keys = list(quant_config.keys()) +# q_vals = [quant_config[k] for k in quant_config][0] +# if isinstance(q_vals, dict): +# q_keys = q_keys + list([quant_config[k] for k in quant_config][0].keys()) +# return "weight_quant_params" in q_keys -# Returns a new module from a dummy (meta) module and a dictionary of module name -> state_dict -def load_hqq_module(module, weights, compute_dtype, device): - if is_hqq_available(): - from hqq.core.quantize import HQQLinear +# # Returns a new module from a dummy (meta) module and a dictionary of module name -> state_dict +# def load_hqq_module(module, weights, compute_dtype, device): +# if is_hqq_available(): +# from hqq.core.quantize import HQQLinear - if module.name not in weights: - try: - return module.to(dtype=compute_dtype, device=device) - except Exception: - return module +# if module.name not in weights: +# try: +# return module.to(dtype=compute_dtype, device=device) +# except Exception: +# return module - state_dict = weights[module.name] - if ("W_q" 
in state_dict) and ("meta" in state_dict): - module = HQQLinear(linear_layer=None, quant_config=None, compute_dtype=compute_dtype, device=device) - module.load_state_dict(state_dict) - else: - for key in state_dict: - setattr( - module, - key, - torch.nn.Parameter(state_dict[key].to(device=device, dtype=compute_dtype), requires_grad=False), - ) +# state_dict = weights[module.name] +# if ("W_q" in state_dict) and ("meta" in state_dict): +# module = HQQLinear(linear_layer=None, quant_config=None, compute_dtype=compute_dtype, device=device) +# module.load_state_dict(state_dict) +# else: +# for key in state_dict: +# setattr( +# module, +# key, +# torch.nn.Parameter(state_dict[key].to(device=device, dtype=compute_dtype), requires_grad=False), +# ) - return module +# return module From 651a5863319b070c8e34378ad57c6366b88ceed3 Mon Sep 17 00:00:00 2001 From: mobicham <37179323+mobicham@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:07:04 +0200 Subject: [PATCH 17/73] Delete src/transformers/utils/hqq_utils.py --- src/transformers/utils/hqq_utils.py | 70 ----------------------------- 1 file changed, 70 deletions(-) delete mode 100755 src/transformers/utils/hqq_utils.py diff --git a/src/transformers/utils/hqq_utils.py b/src/transformers/utils/hqq_utils.py deleted file mode 100755 index 9b4c750ca3b5..000000000000 --- a/src/transformers/utils/hqq_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import is_hqq_available, is_torch_available - - -if is_torch_available(): - import torch - - -# # checks if a module is a leaf: doesn't have another module inside -# def is_leaf_module(module): -# return len(module._modules) == 0 - - -# # Returns layers to ignores. 
These layers are typically not leaves we are interested in for storage and loading -# def get_ignore_layers(model): -# layers = {""} -# for name, module in model.named_modules(): -# if not is_leaf_module(module): -# layers.add(name) -# return list(layers) - - -# # Checks if a quant config is an HQQ quant config -# def check_if_hqq_quant_config(quant_config): -# if quant_config is None: -# return False -# q_keys = list(quant_config.keys()) -# q_vals = [quant_config[k] for k in quant_config][0] -# if isinstance(q_vals, dict): -# q_keys = q_keys + list([quant_config[k] for k in quant_config][0].keys()) -# return "weight_quant_params" in q_keys - - -# # Returns a new module from a dummy (meta) module and a dictionary of module name -> state_dict -# def load_hqq_module(module, weights, compute_dtype, device): -# if is_hqq_available(): -# from hqq.core.quantize import HQQLinear - -# if module.name not in weights: -# try: -# return module.to(dtype=compute_dtype, device=device) -# except Exception: -# return module - -# state_dict = weights[module.name] -# if ("W_q" in state_dict) and ("meta" in state_dict): -# module = HQQLinear(linear_layer=None, quant_config=None, compute_dtype=compute_dtype, device=device) -# module.load_state_dict(state_dict) -# else: -# for key in state_dict: -# setattr( -# module, -# key, -# torch.nn.Parameter(state_dict[key].to(device=device, dtype=compute_dtype), requires_grad=False), -# ) - -# return module From d07ea850290b9215ad06ddf06dcdb3ea4ff89395 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 16:13:37 +0000 Subject: [PATCH 18/73] ruff init --- src/transformers/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index b4b0181e0dbc..2bfa5638df92 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -112,7 +112,6 @@ is_auto_gptq_available, is_av_available, is_bitsandbytes_available, - is_hqq_available, is_bs4_available, is_coloredlogs_available, is_cv2_available, @@ -130,6 +129,7 @@ is_ftfy_available, is_g2p_en_available, is_galore_torch_available, + is_hqq_available, is_in_notebook, is_ipex_available, is_jieba_available, From dedf69ec94ed9b3ac86e45e272b6e608b4c2fbc0 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 17:07:20 +0000 Subject: [PATCH 19/73] remove torch.float16 from __init__ in test --- tests/quantization/hqq/test_hqq.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index ac625dfb5f33..d9bd57f5b0fb 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -50,7 +50,7 @@ def test_to_dict(self): class HQQLLMRunner: - def __init__(self, model_id, quant_config=None, compute_dtype=torch.float16, device="cuda", cache_dir=None): + def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir): self.model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=compute_dtype, @@ -122,8 +122,8 @@ def test_fp16_quantized_model(self): model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, - cache_dir=cache_dir, device=device, + cache_dir=cache_dir, ) test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) @@ -155,8 +155,8 @@ def test_bfp16_quantized_model_with_offloading(self): model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, - cache_dir=cache_dir, device=device, + cache_dir=cache_dir, ) 
test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) @@ -184,8 +184,8 @@ def test_fp16_quantized_model_multipgpu(self): model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, - cache_dir=cache_dir, device="auto", + cache_dir=cache_dir, ) test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) From 0edf8a4390e2c8c6e01018a46d00d344c6df1e05 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 25 Apr 2024 17:10:06 +0000 Subject: [PATCH 20/73] refactor test --- tests/quantization/hqq/test_hqq.py | 50 +++++++++++++++--------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index d9bd57f5b0fb..6b59ff4f00bc 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -36,19 +36,6 @@ from hqq.core.quantize import HQQBackend, HQQLinear -@require_torch_gpu -class HqqConfigTest(unittest.TestCase): - def test_to_dict(self): - """ - Makes sure the config format is properly set - """ - quantization_config = HqqConfig() - hqq_orig_config = quantization_config.to_dict() - - for key in hqq_orig_config: - self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key]) - - class HQQLLMRunner: def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir): self.model = AutoModelForCausalLM.from_pretrained( @@ -69,7 +56,7 @@ def cleanup(): gc.collect() -def test_hqqlayer(self, hqq_layer, batch_size=1, context_size=1024): +def check_hqqlayer(test_module, hqq_layer, batch_size=1, context_size=1024): # Test HQQ layer W_r = hqq_layer.dequantize() x = ( @@ -82,18 +69,18 @@ def test_hqqlayer(self, hqq_layer, batch_size=1, context_size=1024): ) with torch.no_grad(): y = hqq_layer(x) - self.assertEqual(y.shape[-1], W_r.shape[0]) - self.assertEqual(y.dtype, hqq_layer.compute_dtype) + test_module.assertEqual(y.shape[-1], W_r.shape[0]) + test_module.assertEqual(y.dtype, hqq_layer.compute_dtype) del W_r, x, y cleanup() -def test_forward(self, model, batch_size=1, context_size=1024): +def check_forward(test_module, model, batch_size=1, context_size=1024): # Test forward pass with torch.no_grad(): out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits - self.assertEqual(out.shape[0], batch_size) - self.assertEqual(out.shape[1], context_size) + test_module.assertEqual(out.shape[0], batch_size) + test_module.assertEqual(out.shape[1], context_size) cleanup() @@ -101,6 +88,19 @@ def test_forward(self, model, batch_size=1, context_size=1024): # model_id ="mistralai/Mistral-7B-Instruct-v0.2" +@require_torch_gpu +class HqqConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Makes sure the config format is properly set + """ + quantization_config = HqqConfig() + hqq_orig_config = quantization_config.to_dict() + + for key in hqq_orig_config: + self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key]) + + @slow @require_torch_gpu @require_accelerate @@ -126,8 +126,8 @@ def test_fp16_quantized_model(self): cache_dir=cache_dir, ) - test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) - test_forward(self, hqq_runner.model) + check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + check_forward(self, hqq_runner.model) def test_bfp16_quantized_model_with_offloading(self): """ @@ -159,8 +159,8 @@ def test_bfp16_quantized_model_with_offloading(self): cache_dir=cache_dir, ) - test_hqqlayer(self, 
hqq_runner.model.model.layers[0].self_attn.v_proj) - test_forward(self, hqq_runner.model) + check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + check_forward(self, hqq_runner.model) @slow @@ -188,5 +188,5 @@ def test_fp16_quantized_model_multipgpu(self): cache_dir=cache_dir, ) - test_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) - test_forward(self, hqq_runner.model) + check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) + check_forward(self, hqq_runner.model) From c7ec12399b6fd6fedf828ea62b66fb77e969c03c Mon Sep 17 00:00:00 2001 From: mobicham Date: Fri, 26 Apr 2024 07:30:52 +0000 Subject: [PATCH 21/73] isinstance -> type in quantizer_hqq.py --- src/transformers/quantizers/quantizer_hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 5f9a821dc608..f5a59ff26261 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -96,7 +96,7 @@ def check_quantized_param( ) -> bool: module, tensor_name = get_module_from_name(model, param_name) - return isinstance(module, torch.nn.Linear) + return type(module) is torch.nn.Linear def create_quantized_param( self, From 5283ac2043c76d917247a845295bae0b30846a3a Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 09:53:49 +0000 Subject: [PATCH 22/73] cpu/disk device_map check in quantizer_hqq.py --- src/transformers/quantizers/quantizer_hqq.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index f5a59ff26261..35de3b82eb6c 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -80,11 +80,16 @@ def validate_environment(self, *args, **kwargs): if self.torch_dtype is None: self.torch_dtype = kwargs.get("torch_dtype", torch.float16) + device_map = kwargs.get("device_map", None) if self.using_multi_gpu is False: - if "device_map" in kwargs: - if isinstance(kwargs["device_map"], dict): - target_devices = {item[1] for item in kwargs["device_map"].items()} - set({"cpu", "disk"}) - self.using_multi_gpu = len(target_devices) > 1 + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + raise ValueError( + "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." 
+ ) + else: + self.using_multi_gpu = len(set(device_map.values())) > 1 def check_quantized_param( self, From 15daeb484cf9f2d17bca715957402755a6efa5f5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 09:56:37 +0000 Subject: [PATCH 23/73] remove type(module) nn.linear check in quantizer_hqq.py --- src/transformers/quantizers/quantizer_hqq.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 35de3b82eb6c..348f4d10e8ae 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -123,10 +123,6 @@ def create_quantized_param( module, tensor_name = get_module_from_name(model, param_name) - # We only quantize torch.nn.Linear, other layers will be skipped - if type(module) is not torch.nn.Linear: - return - layer_name = param_name.replace(".weight", "").replace(".bias", "") parent_module = find_parent(model, layer_name) node = layer_name.split(".")[-1] From bc4bc73eb35cd87ef31ab292b7846c02c5ba3c22 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:00:59 +0000 Subject: [PATCH 24/73] add BaseQuantizeConfig import inside HqqConfig init --- src/transformers/utils/quantization_config.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index aae6dd3eecbc..bc76866b3ed3 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -30,11 +30,6 @@ if is_torch_available(): import torch -if is_hqq_available(): - from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig -else: - HQQBaseQuantizeConfig = None - logger = logging.get_logger(__name__) @@ -229,6 +224,9 @@ def __init__( show_progress=True, **kwargs, ): + if is_hqq_available(): + from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + if dynamic_config is not None: self.quant_config = {} for key in dynamic_config: From b54e87b204fb72941ac081c54ad32611a3918543 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:03:24 +0000 Subject: [PATCH 25/73] remove hqq import in hqq.py --- src/transformers/integrations/hqq.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index ab9125bb9f9b..f1b38050e8b2 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -21,11 +21,6 @@ else: torch = None -if is_hqq_available(): - from hqq.core.quantize import HQQLinear -else: - HQQLinear = None - logger = logging.get_logger(__name__) From 0f9698afbc307971fd779aca80a2bc79aba12ae9 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:05:05 +0000 Subject: [PATCH 26/73] remove accelerate import from test_hqq.py --- tests/quantization/hqq/test_hqq.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 6b59ff4f00bc..275c2413219d 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -23,15 +23,12 @@ require_torch_multi_gpu, slow, ) -from transformers.utils import is_accelerate_available, is_hqq_available, is_torch_available +from transformers.utils import is_hqq_available, is_torch_available if is_torch_available(): import torch -if is_accelerate_available(): - pass - if is_hqq_available(): from hqq.core.quantize import 
HQQBackend, HQQLinear From d31837fb7a8f7b7e2a2fabad169fd97d211d226a Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:12:15 +0000 Subject: [PATCH 27/73] quant config.py doc update --- src/transformers/utils/quantization_config.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index bc76866b3ed3..7f2a4672fba1 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -186,26 +186,26 @@ class HqqConfig(QuantizationConfigMixin): """ Main HqqConfig. Args: - nbits (`int`): + nbits (`int`, defaults to 4. Supported: 8, 4, 3, 2, 1): Number of bits. - group_size (`int`): + group_size (`int`, defaults to 64. Supported: any value that is divisble by weight.shape[axis]): Group-size value. - quant_zero (`bool`): + quant_zero (`bool`, defaults to False): Quantize the zero-point. - quant_scale (`bool`): + quant_scale (`bool`, defaults to False): Quantize the scaling. - offload_meta (`bool`): + offload_meta (`bool`, defaults to False): Offload the meta-data on the CPU. - view_as_float (`bool`): + view_as_float (`bool`, defaults to False): View the quantized weight as float (used in distributed training) - int (`axis`): - Axis along-which grouping is performed. - dynamic_config ('dict'): + int (`axis`, defaults to 0. Supported: 0, 1): + Axis along which grouping is performed. + dynamic_config ('dict', defaults to None): Parameters for dynamic configuration. The key is the name tag of the layer. - skip_modules (List[str]):: - nn,Linear layers to skip. - show_progress (bool): - Show tqdm quantization for each shard + skip_modules (List[str]): + List of nn.Linear layers to skip. + show_progress ('bool', defaults to True): + Show tqdm quantization progress for each shard. kwargs (`Dict[str, Any]`): Additional parameters from which to initialize the configuration object. """ From b8f792c74284670196c600bdcd73f55a81caca42 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:15:46 +0000 Subject: [PATCH 28/73] add hqqconfig to main_classes doc --- docs/source/en/main_classes/quantization.md | 4 ++++ 1 file changed, 4 insertions(+) mode change 100644 => 100755 docs/source/en/main_classes/quantization.md diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md old mode 100644 new mode 100755 index 91de5fc8a33c..f1e2acdcfe48 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -52,3 +52,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide. 
## HfQuantizer [[autodoc]] quantizers.base.HfQuantizer + +## HqqConfig + +[[autodoc]] HqqConfig From 9a061e562b858f6e457f5496930dc634557a2a17 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:34:33 +0000 Subject: [PATCH 29/73] make style --- src/transformers/__init__.py | 2 +- src/transformers/integrations/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 60a43bffc807..4603bd3be17b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1133,8 +1133,8 @@ "BitsAndBytesConfig", "EetqConfig", "GPTQConfig", - "QuantoConfig", "HqqConfig", + "QuantoConfig", ], } diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index a9042ce6d0af..5a90db22ad7d 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -18,7 +18,6 @@ _import_structure = { "aqlm": ["replace_with_aqlm_linear"], - "hqq": ["prepare_for_hqq_linear"], "awq": [ "fuse_awq_modules", "post_init_awq_exllama_modules", @@ -44,6 +43,7 @@ "unset_hf_deepspeed_config", ], "eetq": ["replace_with_eetq_linear"], + "hqq": ["prepare_for_hqq_linear"], "integration_utils": [ "INTEGRATION_TO_CALLBACK", "AzureMLCallback", From 86122823c5e25ea3d30b2955fc0c6f1a893fea24 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:50:12 +0000 Subject: [PATCH 30/73] __init__ fix --- src/transformers/__init__.py | 1 + src/transformers/integrations/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4603bd3be17b..0e83c1f90c29 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6099,6 +6099,7 @@ EetqConfig, GPTQConfig, QuantoConfig, + HqqConfig, ) try: diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 5a90db22ad7d..69fb0e3259b1 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -114,6 +114,7 @@ unset_hf_deepspeed_config, ) from .eetq import replace_with_eetq_linear + from .hqq import prepare_for_hqq_linear from .integration_utils import ( INTEGRATION_TO_CALLBACK, AzureMLCallback, From b78679322f3529a65ada6a01662d40f61513d767 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 10:55:10 +0000 Subject: [PATCH 31/73] ruff __init__ --- src/transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0e83c1f90c29..9b2709abe8c8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6098,8 +6098,8 @@ BitsAndBytesConfig, EetqConfig, GPTQConfig, - QuantoConfig, HqqConfig, + QuantoConfig, ) try: From e7ba7170e6d9b266e2635b7d7e3a5797834eca22 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 11:07:20 +0000 Subject: [PATCH 32/73] skip_modules list --- src/transformers/utils/quantization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 7f2a4672fba1..e93d86ce8e10 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -202,7 +202,7 @@ class HqqConfig(QuantizationConfigMixin): Axis along which grouping is performed. dynamic_config ('dict', defaults to None): Parameters for dynamic configuration. 
The key is the name tag of the layer. - skip_modules (List[str]): + skip_modules (`List[str]`): List of nn.Linear layers to skip. show_progress ('bool', defaults to True): Show tqdm quantization progress for each shard. From 3a38f2109a934f7f39e34b6fd1bdc78ddc227e5a Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 11:15:59 +0000 Subject: [PATCH 33/73] hqqconfig format fix --- src/transformers/utils/quantization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e93d86ce8e10..9811d13c01e6 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -198,7 +198,7 @@ class HqqConfig(QuantizationConfigMixin): Offload the meta-data on the CPU. view_as_float (`bool`, defaults to False): View the quantized weight as float (used in distributed training) - int (`axis`, defaults to 0. Supported: 0, 1): + axis (`int`, defaults to 0. Supported: 0, 1): Axis along which grouping is performed. dynamic_config ('dict', defaults to None): Parameters for dynamic configuration. The key is the name tag of the layer. From 9eee21314bd73d7859bb07120c8fca439cf797c2 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 11:31:49 +0000 Subject: [PATCH 34/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 9811d13c01e6..e28e2283955d 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -184,29 +184,30 @@ def update(self, **kwargs): @dataclass class HqqConfig(QuantizationConfigMixin): """ - Main HqqConfig. + This is wrapper around hqq's BaseQuantizeConfig. + Args: - nbits (`int`, defaults to 4. Supported: 8, 4, 3, 2, 1): - Number of bits. - group_size (`int`, defaults to 64. Supported: any value that is divisble by weight.shape[axis]): - Group-size value. - quant_zero (`bool`, defaults to False): - Quantize the zero-point. - quant_scale (`bool`, defaults to False): - Quantize the scaling. - offload_meta (`bool`, defaults to False): - Offload the meta-data on the CPU. - view_as_float (`bool`, defaults to False): - View the quantized weight as float (used in distributed training) - axis (`int`, defaults to 0. Supported: 0, 1): - Axis along which grouping is performed. - dynamic_config ('dict', defaults to None): + nbits (`int`, defaults to 4): + Number of bits. Supported values are (8, 4, 3, 2, 1). + group_size (`int`, defaults to 64): + Group-size value. Supported values are any value that is divisble by weight.shape[axis]). + quant_zero (`bool`, defaults to `False`): + Quantize the zero-point if set to True. + quant_scale (`bool`, defaults to `False`): + Quantize the scaling if set to True. + offload_meta (`bool`, defaults to `False`): + Offload the meta-data on the CPU if set to True. + view_as_float (`bool`, defaults to `False`): + View the quantized weight as float (used in distributed training) if set to True. + axis (`int`, defaults to 0): + Axis along which grouping is performed. Supported values are 0 or 1. + dynamic_config ('dict', defaults to `None`): Parameters for dynamic configuration. The key is the name tag of the layer. skip_modules (`List[str]`): List of nn.Linear layers to skip. 
- show_progress ('bool', defaults to True): + show_progress ('bool', defaults to `True`): Show tqdm quantization progress for each shard. - kwargs (`Dict[str, Any]`): + kwargs (`Dict[str, Any]`, *optional*): Additional parameters from which to initialize the configuration object. """ From 03cc8e6c010e971e01604065b7812c027323e624 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 11:42:35 +0000 Subject: [PATCH 35/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e28e2283955d..9442ec133500 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -201,9 +201,9 @@ class HqqConfig(QuantizationConfigMixin): View the quantized weight as float (used in distributed training) if set to True. axis (`int`, defaults to 0): Axis along which grouping is performed. Supported values are 0 or 1. - dynamic_config ('dict', defaults to `None`): + dynamic_config ('Optional[dict]', defaults to `None`): Parameters for dynamic configuration. The key is the name tag of the layer. - skip_modules (`List[str]`): + skip_modules (`List[str]`, defaults to `["lm_head"]`): List of nn.Linear layers to skip. show_progress ('bool', defaults to `True`): Show tqdm quantization progress for each shard. @@ -221,8 +221,8 @@ def __init__( view_as_float: bool = False, axis: int = 0, dynamic_config: Optional[dict] = None, - skip_modules=["lm_head"], - show_progress=True, + skip_modules: List[str] = ["lm_head"], + show_progress: bool = True, **kwargs, ): if is_hqq_available(): From 96bd141b9b9c74bc00f1a5e3c9288289e0a30c25 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 11:45:50 +0000 Subject: [PATCH 36/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 9442ec133500..bfefd8b654ad 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -191,17 +191,17 @@ class HqqConfig(QuantizationConfigMixin): Number of bits. Supported values are (8, 4, 3, 2, 1). group_size (`int`, defaults to 64): Group-size value. Supported values are any value that is divisble by weight.shape[axis]). - quant_zero (`bool`, defaults to `False`): + quant_zero (`bool`, defaults to `True`): Quantize the zero-point if set to True. quant_scale (`bool`, defaults to `False`): Quantize the scaling if set to True. offload_meta (`bool`, defaults to `False`): - Offload the meta-data on the CPU if set to True. + Offload the meta-data to the CPU if set to True. view_as_float (`bool`, defaults to `False`): View the quantized weight as float (used in distributed training) if set to True. axis (`int`, defaults to 0): Axis along which grouping is performed. Supported values are 0 or 1. - dynamic_config ('Optional[dict]', defaults to `None`): + dynamic_config (dict, *optional*, defaults to `None`): Parameters for dynamic configuration. The key is the name tag of the layer. skip_modules (`List[str]`, defaults to `["lm_head"]`): List of nn.Linear layers to skip. 
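For readers tracking the HqqConfig docstring revisions in the patches above, the arguments being documented are typically combined as in the sketch below. This is an illustrative example rather than part of any patch in this series; it assumes the `hqq` package is installed, and the linear-tag keys are examples mirroring the ones used elsewhere in this PR.

```python
from transformers import HqqConfig

# One shared config for every nn.Linear layer; modules listed in `skip_modules`
# (by default "lm_head") are left unquantized. As documented above, group_size
# must divide weight.shape[axis].
shared_config = HqqConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=0)

# Per-layer settings via `dynamic_config`: keys are linear tags (e.g. "self_attn.q_proj"),
# values reuse the same parameter names described in the docstring above.
per_tag_config = HqqConfig(
    dynamic_config={
        "self_attn.q_proj": {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False},
        "mlp.down_proj": {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False},
    }
)
```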
From 713d2261b61416a5816a3f0e520e97eada5d4004 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 11:47:59 +0000 Subject: [PATCH 37/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index bfefd8b654ad..954c387e4433 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -203,7 +203,7 @@ class HqqConfig(QuantizationConfigMixin): Axis along which grouping is performed. Supported values are 0 or 1. dynamic_config (dict, *optional*, defaults to `None`): Parameters for dynamic configuration. The key is the name tag of the layer. - skip_modules (`List[str]`, defaults to `["lm_head"]`): + skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`): List of nn.Linear layers to skip. show_progress ('bool', defaults to `True`): Show tqdm quantization progress for each shard. From dad9a60d85edf8cba7acacc21e7e30b9657839d8 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 12:00:00 +0000 Subject: [PATCH 38/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 954c387e4433..d6e0640b521b 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -201,9 +201,9 @@ class HqqConfig(QuantizationConfigMixin): View the quantized weight as float (used in distributed training) if set to True. axis (`int`, defaults to 0): Axis along which grouping is performed. Supported values are 0 or 1. - dynamic_config (dict, *optional*, defaults to `None`): + dynamic_config (dict, *optional*): Parameters for dynamic configuration. The key is the name tag of the layer. - skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`): + skip_modules (`List[str]`, *optional*): List of nn.Linear layers to skip. show_progress ('bool', defaults to `True`): Show tqdm quantization progress for each shard. From 67c0985d98912e6b7eea3c916978fb847ab52582 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 12:14:40 +0000 Subject: [PATCH 39/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index d6e0640b521b..dbfce5ca4114 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -205,7 +205,7 @@ class HqqConfig(QuantizationConfigMixin): Parameters for dynamic configuration. The key is the name tag of the layer. skip_modules (`List[str]`, *optional*): List of nn.Linear layers to skip. - show_progress ('bool', defaults to `True`): + show_progress ('bool', *optional*, defaults to `True`): Show tqdm quantization progress for each shard. kwargs (`Dict[str, Any]`, *optional*): Additional parameters from which to initialize the configuration object. 
@@ -221,8 +221,8 @@ def __init__( view_as_float: bool = False, axis: int = 0, dynamic_config: Optional[dict] = None, - skip_modules: List[str] = ["lm_head"], - show_progress: bool = True, + skip_modules: Optional[List[str]] = ["lm_head"], + show_progress: Optional[bool] = True, **kwargs, ): if is_hqq_available(): From 94c393a853174b784229bf29a789150d1aae17d9 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 12:19:00 +0000 Subject: [PATCH 40/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index dbfce5ca4114..cf8ddd1fd5a2 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -187,25 +187,25 @@ class HqqConfig(QuantizationConfigMixin): This is wrapper around hqq's BaseQuantizeConfig. Args: - nbits (`int`, defaults to 4): + nbits (`int`, *optional*, defaults to 4): Number of bits. Supported values are (8, 4, 3, 2, 1). - group_size (`int`, defaults to 64): + group_size (`int`, *optional*, defaults to 64): Group-size value. Supported values are any value that is divisble by weight.shape[axis]). - quant_zero (`bool`, defaults to `True`): + quant_zero (`bool`, *optional*, defaults to `True`): Quantize the zero-point if set to True. - quant_scale (`bool`, defaults to `False`): + quant_scale (`bool`, *optional*, defaults to `False`): Quantize the scaling if set to True. - offload_meta (`bool`, defaults to `False`): + offload_meta (`bool`, *optional*, defaults to `False`): Offload the meta-data to the CPU if set to True. - view_as_float (`bool`, defaults to `False`): + view_as_float (`bool`, *optional*, defaults to `False`): View the quantized weight as float (used in distributed training) if set to True. - axis (`int`, defaults to 0): + axis (`int`, defaults to 0, *optional*, defaults to 0): Axis along which grouping is performed. Supported values are 0 or 1. dynamic_config (dict, *optional*): Parameters for dynamic configuration. The key is the name tag of the layer. - skip_modules (`List[str]`, *optional*): + skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): List of nn.Linear layers to skip. - show_progress ('bool', *optional*, defaults to `True`): + show_progress ('bool', defaults to `True`, *optional*, defaults to `True`): Show tqdm quantization progress for each shard. kwargs (`Dict[str, Any]`, *optional*): Additional parameters from which to initialize the configuration object. From 35fc9f50150e386cc526d778e4d96bc5799cabda Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 12:38:05 +0000 Subject: [PATCH 41/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index cf8ddd1fd5a2..e5932494f002 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -199,13 +199,13 @@ class HqqConfig(QuantizationConfigMixin): Offload the meta-data to the CPU if set to True. view_as_float (`bool`, *optional*, defaults to `False`): View the quantized weight as float (used in distributed training) if set to True. - axis (`int`, defaults to 0, *optional*, defaults to 0): + axis (`int`, *optional*, defaults to 0): Axis along which grouping is performed. 
Supported values are 0 or 1. dynamic_config (dict, *optional*): Parameters for dynamic configuration. The key is the name tag of the layer. skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): List of nn.Linear layers to skip. - show_progress ('bool', defaults to `True`, *optional*, defaults to `True`): + show_progress ('bool', *optional*, defaults to `True`): Show tqdm quantization progress for each shard. kwargs (`Dict[str, Any]`, *optional*): Additional parameters from which to initialize the configuration object. @@ -221,10 +221,11 @@ def __init__( view_as_float: bool = False, axis: int = 0, dynamic_config: Optional[dict] = None, - skip_modules: Optional[List[str]] = ["lm_head"], - show_progress: Optional[bool] = True, + skip_modules: List[str] = ["lm_head"], + show_progress: bool = True, **kwargs, ): + if is_hqq_available(): from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig From 06f649786bcf9951cf6ae4a63d59cf17382adc6f Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 29 Apr 2024 12:45:59 +0000 Subject: [PATCH 42/73] hqqconfig doc fix --- src/transformers/utils/quantization_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e5932494f002..92b5ed81a32e 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -225,7 +225,6 @@ def __init__( show_progress: bool = True, **kwargs, ): - if is_hqq_available(): from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig From 25fde9c7d98e05e9ead89b984798222d7cc8f66b Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 30 Apr 2024 07:39:42 +0000 Subject: [PATCH 43/73] test_hqq.py remove mistral comment --- tests/quantization/hqq/test_hqq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 275c2413219d..6dbbdb405686 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -82,7 +82,6 @@ def check_forward(test_module, model, batch_size=1, context_size=1024): model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -# model_id ="mistralai/Mistral-7B-Instruct-v0.2" @require_torch_gpu From ee50516c34b91c4aafb18646b0e901fab854b61f Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 30 Apr 2024 07:42:52 +0000 Subject: [PATCH 44/73] remove self.using_multi_gpu is False --- src/transformers/quantizers/quantizer_hqq.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 348f4d10e8ae..66ba7d6c8e68 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -81,15 +81,14 @@ def validate_environment(self, *args, **kwargs): self.torch_dtype = kwargs.get("torch_dtype", torch.float16) device_map = kwargs.get("device_map", None) - if self.using_multi_gpu is False: - if isinstance(device_map, dict): - if "cpu" in device_map.values() or "disk" in device_map.values(): - raise ValueError( - "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device." - " This is not supported. Please remove the CPU or disk device from the device_map." 
- ) - else: - self.using_multi_gpu = len(set(device_map.values())) > 1 + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + raise ValueError( + "You are attempting to use an HQQ model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." + ) + else: + self.using_multi_gpu = len(set(device_map.values())) > 1 def check_quantized_param( self, From 01d798a4a93ddaf84e36337ba30e4dd7488064c1 Mon Sep 17 00:00:00 2001 From: mobicham Date: Tue, 30 Apr 2024 07:53:39 +0000 Subject: [PATCH 45/73] torch_dtype default val set and logger.info --- src/transformers/quantizers/quantizer_hqq.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 66ba7d6c8e68..450ecaa2233d 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -78,7 +78,11 @@ def validate_environment(self, *args, **kwargs): raise RuntimeError("No GPU found. A GPU is needed for quantization.") if self.torch_dtype is None: - self.torch_dtype = kwargs.get("torch_dtype", torch.float16) + if "torch_dtype" in kwargs: + self.torch_dtype = kwargs["torch_dtype"] + else: + self.torch_dtype = torch.float32 + logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.") device_map = kwargs.get("device_map", None) if isinstance(device_map, dict): From a909ca8a38b9a2836fb960aaae0dba035e9cd0c1 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:15:06 +0000 Subject: [PATCH 46/73] hqq.py isinstance fix --- src/transformers/integrations/hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index f1b38050e8b2..621963d761a9 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -42,7 +42,7 @@ def get_linear_tags(model): linear_tags = set() for name, module in model.named_modules(): - if type(module) in [torch.nn.Linear, HQQLinear]: + if isinstance(module, (torch.nn.Linear, HQQLinear)): linear_tags.add(name_to_linear_tag(name)) return list(linear_tags) From c466c89af1738a3e356ca21f0480af54c72ed586 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:16:27 +0000 Subject: [PATCH 47/73] remove torch=None --- src/transformers/integrations/hqq.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index 621963d761a9..10a6d06a3f9f 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -18,8 +18,6 @@ if is_torch_available(): import torch -else: - torch = None logger = logging.get_logger(__name__) From d522fed90c98d1c8507a0320282d846f7febe77a Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:24:15 +0000 Subject: [PATCH 48/73] torch_device test_hqq --- tests/quantization/hqq/test_hqq.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 6dbbdb405686..a0df864b473a 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -109,7 +109,7 @@ def test_fp16_quantized_model(self): Simple LLM model testing fp16 """ compute_dtype = torch.float16 - device = "cuda:0" + torch_device = "cuda:0" cache_dir = None 
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) @@ -118,7 +118,7 @@ def test_fp16_quantized_model(self): model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, - device=device, + device=torch_device, cache_dir=cache_dir, ) @@ -130,7 +130,7 @@ def test_bfp16_quantized_model_with_offloading(self): Simple LLM model testing bfp16 with meta-data offloading """ compute_dtype = torch.bfloat16 - device = "cuda:0" + torch_device = "cuda:0" cache_dir = None q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False} @@ -151,7 +151,7 @@ def test_bfp16_quantized_model_with_offloading(self): model_id=model_id, quant_config=quant_config, compute_dtype=compute_dtype, - device=device, + device=torch_device, cache_dir=cache_dir, ) From a09e90ffa9138dc748e847982c3023d6123cafc4 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:27:31 +0000 Subject: [PATCH 49/73] rename test_hqq --- tests/quantization/hqq/test_hqq.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index a0df864b473a..22e4d505fe7f 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -55,8 +55,8 @@ def cleanup(): def check_hqqlayer(test_module, hqq_layer, batch_size=1, context_size=1024): # Test HQQ layer - W_r = hqq_layer.dequantize() - x = ( + W_dequant = hqq_layer.dequantize() # Reconstructed weights + inputs = ( torch.randn( (batch_size, context_size, hqq_layer.meta["shape"][1]), device=hqq_layer.device, @@ -65,10 +65,10 @@ def check_hqqlayer(test_module, hqq_layer, batch_size=1, context_size=1024): / 10.0 ) with torch.no_grad(): - y = hqq_layer(x) - test_module.assertEqual(y.shape[-1], W_r.shape[0]) - test_module.assertEqual(y.dtype, hqq_layer.compute_dtype) - del W_r, x, y + outputs = hqq_layer(inputs) + test_module.assertEqual(outputs.shape[-1], W_dequant.shape[0]) + test_module.assertEqual(outputs.dtype, hqq_layer.compute_dtype) + del W_dequant, inputs, outputs cleanup() From 5bdf40f4217e531c126c829ee0821a7b237d3ae1 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:29:00 +0000 Subject: [PATCH 50/73] MODEL_ID in test_hqq --- tests/quantization/hqq/test_hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 22e4d505fe7f..89323be3f0d6 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -81,7 +81,7 @@ def check_forward(test_module, model, batch_size=1, context_size=1024): cleanup() -model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" @require_torch_gpu From e693d4731c6053af9834c02f64c29426a357e060 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:31:32 +0000 Subject: [PATCH 51/73] quantizer_hqq setattr fix --- src/transformers/quantizers/quantizer_hqq.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 450ecaa2233d..3980aa967732 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -158,11 +158,7 @@ def create_quantized_param( if self.using_multi_gpu: hqq_layer = self._patch_layer_for_multigpu(hqq_layer) - setattr( - parent_module, - node, - hqq_layer, - ) + setattr(parent_module, node, hqq_layer) else: module = 
module.to(dtype=self.torch_dtype, device=target_device) From f5cabe580df21588ab8cd89e5797a3258f4572c5 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:33:14 +0000 Subject: [PATCH 52/73] quantizer_hqq typo fix --- src/transformers/quantizers/quantizer_hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 3980aa967732..aeebcb33a824 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -47,7 +47,7 @@ class HQQHfQuantizer(HfQuantizer): """ HQQ quantizer base HF class. nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading(). - The actually quantization and offloading to the GPU is done in check_quantized_param(). + The actual quantization and offloading to the GPU is done in check_quantized_param(). self.show_progress (bool) is used to show quantization progress in each shard. """ From 5ede086e310028686b6be31ee172ae72d44b4758 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:35:18 +0000 Subject: [PATCH 53/73] imports quantizer_hqq.py --- src/transformers/quantizers/quantizer_hqq.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index aeebcb33a824..2295e8a8345a 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -14,16 +14,15 @@ from typing import TYPE_CHECKING, Any, Dict, List +from ..integrations import prepare_for_hqq_linear +from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging from .base import HfQuantizer +from .quantizers_utils import get_module_from_name if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..integrations import prepare_for_hqq_linear -from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging -from .quantizers_utils import get_module_from_name - if is_accelerate_available(): from accelerate.hooks import remove_hook_from_module From c86000bc687efd2106c236eefc8cb8fee1e57918 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:40:28 +0000 Subject: [PATCH 54/73] isinstance quantizer_hqq --- src/transformers/quantizers/quantizer_hqq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 2295e8a8345a..a44a470d6ded 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -103,7 +103,7 @@ def check_quantized_param( ) -> bool: module, tensor_name = get_module_from_name(model, param_name) - return type(module) is torch.nn.Linear + return isinstance(module, torch.nn.Linear) def create_quantized_param( self, From 7d3e0839d3162a501053fb332c00619a586931c2 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:42:05 +0000 Subject: [PATCH 55/73] hqq_layer.bias reformat quantizer_hqq --- src/transformers/quantizers/quantizer_hqq.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index a44a470d6ded..6a203b51619e 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -150,9 +150,8 @@ def create_quantized_param( del_orig=True, ) - if 
hqq_layer.bias is not None: - if isinstance(hqq_layer.bias, torch.Tensor): - hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) if self.using_multi_gpu: hqq_layer = self._patch_layer_for_multigpu(hqq_layer) From 082dfea5efab9e852a5fc184d31155a075604ca8 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:43:34 +0000 Subject: [PATCH 56/73] Step 2 as comment in quantizer_hqq --- src/transformers/quantizers/quantizer_hqq.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 6a203b51619e..e6f177fcba5a 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -136,10 +136,8 @@ def create_quantized_param( for key in module_state_dict: setattr(module, key, torch.nn.Parameter(module_state_dict[key])) - """ - Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module - directly doesn't work. - """ + # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module + # directly doesn't work. if hasattr(module, "quant_config"): hqq_layer = HQQLinear( From 667f1adb12ac554fc75ff37ef53972e179c13ebd Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:48:01 +0000 Subject: [PATCH 57/73] prepare_for_hqq_linear() comment --- src/transformers/quantizers/quantizer_hqq.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index e6f177fcba5a..5acb4bbd126f 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -182,11 +182,10 @@ def _process_model_before_weight_loading( keep_in_fp32_modules: List[str] = [], **kwargs, ): - # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear > HQQLinear in create_quantized_param() + # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param(). 
+ # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config) model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config) - # model.config.quantization_config is done inside prepare_for_hqq_linear - def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): model.is_hqq_quantized = True model.is_hqq_serializable = self.is_serializable From e0cd78463cc8c28c7b03f43dfd8d8f7d80b2b310 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:50:26 +0000 Subject: [PATCH 58/73] keep_in_fp32_modules fix --- src/transformers/quantizers/quantizer_hqq.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 5acb4bbd126f..8f11613dfd33 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -179,9 +179,11 @@ def _process_model_before_weight_loading( self, model: "PreTrainedModel", device_map, - keep_in_fp32_modules: List[str] = [], + keep_in_fp32_modules: List[str] = None, **kwargs, ): + keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else [] + # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param(). # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config) model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config) From 5d3b504e4897a299cc91bdf54de44d87fea07b6c Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 08:58:54 +0000 Subject: [PATCH 59/73] HqqHfQuantizer reformat --- src/transformers/quantizers/auto.py | 4 ++-- src/transformers/quantizers/quantizer_hqq.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index b42d2337f4aa..2c65afa77e28 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -32,7 +32,7 @@ from .quantizer_bnb_8bit import Bnb8BitHfQuantizer from .quantizer_eetq import EetqHfQuantizer from .quantizer_gptq import GptqHfQuantizer -from .quantizer_hqq import HQQHfQuantizer +from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer @@ -44,7 +44,7 @@ "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, "eetq": EetqHfQuantizer, - "hqq": HQQHfQuantizer, + "hqq": HqqHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 8f11613dfd33..bb00b368eb0d 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -42,7 +42,7 @@ def find_parent(model, name): return parent -class HQQHfQuantizer(HfQuantizer): +class HqqHfQuantizer(HfQuantizer): """ HQQ quantizer base HF class. nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading(). 
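The quantizer changes in the preceding patches (the `find_parent` helper and the switch to `setattr` on the parent module) serve one replacement pattern: the newly built HQQLinear, or the module moved to the target device, has to be rebound on its parent, since rebinding only the local reference returned by `get_module_from_name` would leave the model unchanged. A minimal, self-contained sketch of that pattern follows; it is not part of the patch series and uses toy module names chosen for illustration.

```python
import torch


def find_parent(model, name):
    # Same walk as the helper added to quantizer_hqq.py: follow the dotted path
    # up to, but not including, the last component.
    module_tree = name.split(".")[:-1]
    parent = model
    for m in module_tree:
        parent = parent._modules[m]
    return parent


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([ToyBlock()])


model = ToyModel()
layer_name = "layers.0.proj"

# Rebinding the attribute on the parent is what actually swaps the submodule
# inside the model; a plain nn.Linear stands in for the quantized layer here.
parent = find_parent(model, layer_name)
node = layer_name.split(".")[-1]
setattr(parent, node, torch.nn.Linear(8, 8, bias=False))

print(type(model.layers[0].proj))  # the replacement is visible from the model root
```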
From cc1961cbb4e3a4007224f199f065a8d4ccb4676b Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:01:55 +0000 Subject: [PATCH 60/73] quantization.md hqqconfig --- docs/source/en/quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index f79ae9540edc..e8ce34c442d6 100755 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -755,7 +755,7 @@ For installation, we recommend you use the following approach to get the latest pip install hqq ``` -To quantize a model, you need to create an ```HqqConfig``` as follows: +To quantize a model, you need to create an [`HqqConfig`] as follows: ``` Python from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig From 9aa9e15ac0bcb94e4a3ffd08d22ac62280f0d5b9 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:02:54 +0000 Subject: [PATCH 61/73] quantization.md model example reformat --- docs/source/en/quantization.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index e8ce34c442d6..8c5f43e88845 100755 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -779,7 +779,12 @@ quant_config = HqqConfig(dynamic_config={ Then you simply quantize the model as follows ``` Python -model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda", quantization_config=quant_config) +model = transformers.AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + device_map="cuda", + quantization_config=quant_config +) ``` ### Optimized Runtime HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training. 
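Putting the documented pieces together, a minimal end-to-end sketch of the flow described in `quantization.md` could look as follows. The model id and prompt are placeholders, and it assumes the `hqq` package is installed and a CUDA device is available.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model id

# Single 4-bit config applied to all nn.Linear layers (axis=0 is the default)
quant_config = HqqConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=0)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

# Quantization happens on the fly while the checkpoint is loaded, so the model
# can be used for generation immediately.
inputs = tokenizer("Half-Quadratic Quantization is", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```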
From 9273e21d01f2901c4da50152a99d3c14e0ad2044 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:04:08 +0000 Subject: [PATCH 62/73] quantization.md # space --- docs/source/en/quantization.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index 8c5f43e88845..7d5e39e2d9fc 100755 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -759,10 +759,10 @@ To quantize a model, you need to create an [`HqqConfig`] as follows: ``` Python from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig -#Linear layers will use the same quantization config +# Linear layers will use the same quantization config quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default -#Each type of linear layer (referred to as linear tag) will use different quantization parameters +# Each type of linear layer (referred to as linear tag) will use different quantization parameters q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False} q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False} quant_config = HqqConfig(dynamic_config={ From f29e7a4e4c9809c4ddb5926a1da6fcb4c9c4c76f Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:05:16 +0000 Subject: [PATCH 63/73] quantization.md space }) --- docs/source/en/quantization.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index 7d5e39e2d9fc..60b1ec62e321 100755 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -759,7 +759,7 @@ To quantize a model, you need to create an [`HqqConfig`] as follows: ``` Python from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig -# Linear layers will use the same quantization config +# })Linear layers will use the same quantization config quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default # Each type of linear layer (referred to as linear tag) will use different quantization parameters @@ -774,7 +774,7 @@ quant_config = HqqConfig(dynamic_config={ 'mlp.gate_proj':q3_config, 'mlp.up_proj' :q3_config, 'mlp.down_proj':q3_config, - }) +}) ``` Then you simply quantize the model as follows From 5168852dea30e249cfead61753186e259142c32d Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:05:51 +0000 Subject: [PATCH 64/73] quantization.md space }) --- docs/source/en/quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index 60b1ec62e321..ea1083fde057 100755 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -759,7 +759,7 @@ To quantize a model, you need to create an [`HqqConfig`] as follows: ``` Python from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig -# })Linear layers will use the same quantization config +# Linear layers will use the same quantization config quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default # Each type of linear layer (referred to as linear tag) will use different quantization parameters From 0dfe0806ed484ed2170948f93e5d902ccdd793ab Mon Sep 17 00:00:00 2001 From: mobicham <37179323+mobicham@users.noreply.github.com> Date: Thu, 2 May 2024 11:08:56 +0200 Subject: 
[PATCH 65/73] quantization_config fix doc Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/utils/quantization_config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 92b5ed81a32e..fbcbb13bb5f7 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -192,19 +192,19 @@ class HqqConfig(QuantizationConfigMixin): group_size (`int`, *optional*, defaults to 64): Group-size value. Supported values are any value that is divisble by weight.shape[axis]). quant_zero (`bool`, *optional*, defaults to `True`): - Quantize the zero-point if set to True. + Quantize the zero-point if set to `True`. quant_scale (`bool`, *optional*, defaults to `False`): - Quantize the scaling if set to True. + Quantize the scaling if set to `True`. offload_meta (`bool`, *optional*, defaults to `False`): - Offload the meta-data to the CPU if set to True. + Offload the meta-data to the CPU if set to `True`. view_as_float (`bool`, *optional*, defaults to `False`): - View the quantized weight as float (used in distributed training) if set to True. + View the quantized weight as float (used in distributed training) if set to `True`. axis (`int`, *optional*, defaults to 0): Axis along which grouping is performed. Supported values are 0 or 1. dynamic_config (dict, *optional*): Parameters for dynamic configuration. The key is the name tag of the layer. skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): - List of nn.Linear layers to skip. + List of `nn.Linear` layers to skip. show_progress ('bool', *optional*, defaults to `True`): Show tqdm quantization progress for each shard. kwargs (`Dict[str, Any]`, *optional*): From 29340526f159aba8f1e7da601cefe267a170abd6 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:12:30 +0000 Subject: [PATCH 66/73] axis value check in quantization_config --- src/transformers/utils/quantization_config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index fbcbb13bb5f7..3d505e4ddac7 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -228,6 +228,10 @@ def __init__( if is_hqq_available(): from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + if axis not in [0, 1]: + raise ValueError("Invalid axis value. Only 0 and 1 are allowed.") + + if dynamic_config is not None: self.quant_config = {} for key in dynamic_config: From bc7cf4ee28e858c50156a58483fdab38d1849565 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:13:33 +0000 Subject: [PATCH 67/73] format --- src/transformers/utils/quantization_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 3d505e4ddac7..da433953315c 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -231,7 +231,6 @@ def __init__( if axis not in [0, 1]: raise ValueError("Invalid axis value. 
Only 0 and 1 are allowed.") - if dynamic_config is not None: self.quant_config = {} for key in dynamic_config: From d33f944af372b438efd0414b02c85dc38fb5ef44 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:16:04 +0000 Subject: [PATCH 68/73] dynamic config explanation --- src/transformers/utils/quantization_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index da433953315c..494ccbeb6384 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -202,7 +202,8 @@ class HqqConfig(QuantizationConfigMixin): axis (`int`, *optional*, defaults to 0): Axis along which grouping is performed. Supported values are 0 or 1. dynamic_config (dict, *optional*): - Parameters for dynamic configuration. The key is the name tag of the layer. + Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config. + If set, each layer specified by its id will use its dedicated quantization configuration. skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): List of `nn.Linear` layers to skip. show_progress ('bool', *optional*, defaults to `True`): From 3522f0a672a31f9f3697f2814ec4574ae39f9a34 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:28:18 +0000 Subject: [PATCH 69/73] quant config method in quantization.md --- docs/source/en/quantization.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md index ea1083fde057..ae4f44f6b800 100755 --- a/docs/source/en/quantization.md +++ b/docs/source/en/quantization.md @@ -755,14 +755,16 @@ For installation, we recommend you use the following approach to get the latest pip install hqq ``` -To quantize a model, you need to create an [`HqqConfig`] as follows: +To quantize a model, you need to create an [`HqqConfig`]. There are two ways of doing it: ``` Python from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig -# Linear layers will use the same quantization config +# Method 1: all linear layers will use the same quantization config quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default +``` -# Each type of linear layer (referred to as linear tag) will use different quantization parameters +``` Python +# Method 2: each linear layer with the same tag will use a dedicated quantization config q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False} q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False} quant_config = HqqConfig(dynamic_config={ @@ -777,6 +779,9 @@ quant_config = HqqConfig(dynamic_config={ }) ``` +The second approach is especially interesting for quantizing Mixture-of-Experts (MoEs) because the experts are less affected by lower quantization settings. 
+ + Then you simply quantize the model as follows ``` Python model = transformers.AutoModelForCausalLM.from_pretrained( From cc14c2118ed359e600d7c6758200c840284a1056 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:40:08 +0000 Subject: [PATCH 70/73] remove shard-level progress --- src/transformers/modeling_utils.py | 9 +-------- src/transformers/quantizers/quantizer_hqq.py | 2 -- src/transformers/utils/quantization_config.py | 2 -- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4eb3ebea7df9..c96550c248f3 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -38,7 +38,6 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss, Identity from torch.utils.checkpoint import checkpoint -from tqdm import tqdm as tqdm_lib from .activations import get_activation from .configuration_utils import PretrainedConfig @@ -809,13 +808,7 @@ def _load_state_dict_into_meta_model( for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) - # Show shard-level progress. Useful to monitor quantization progress - quant_show_progress = False - if hf_quantizer is not None: - if hasattr(hf_quantizer, "show_progress"): - quant_show_progress = hf_quantizer.show_progress - - for param_name, param in tqdm_lib(state_dict.items(), disable=not quant_show_progress): + for param_name, param in state_dict.items(): # First part of the test is always true as load_state_dict_keys always contains state_dict keys. if param_name not in loaded_state_dict_keys or param_name not in expected_keys: continue diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index bb00b368eb0d..dd58c2c1bc5a 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -47,7 +47,6 @@ class HqqHfQuantizer(HfQuantizer): HQQ quantizer base HF class. nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading(). The actual quantization and offloading to the GPU is done in check_quantized_param(). - self.show_progress (bool) is used to show quantization progress in each shard. 
""" use_keep_in_fp32_modules = False @@ -57,7 +56,6 @@ class HqqHfQuantizer(HfQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) - self.show_progress = quantization_config.show_progress self.torch_dtype = None self.using_multi_gpu = False diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 494ccbeb6384..eacf7867106b 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -223,7 +223,6 @@ def __init__( axis: int = 0, dynamic_config: Optional[dict] = None, skip_modules: List[str] = ["lm_head"], - show_progress: bool = True, **kwargs, ): if is_hqq_available(): @@ -251,7 +250,6 @@ def __init__( self.quant_method = QuantizationMethod.HQQ self.skip_modules = skip_modules - self.show_progress = show_progress self.post_init() From 1e81036f0622683808850e555aa8bb5bc3cc5c67 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:42:01 +0000 Subject: [PATCH 71/73] .cuda fix modeling_utils --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c96550c248f3..36a2db46e00b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2657,7 +2657,7 @@ def get_memory_footprint(self, return_buffers=True): @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: - raise ValueError("`.to` is not supported for HQQ-quantized models.") + raise ValueError("`.cuda` is not supported for HQQ-quantized models.") # Checks if the model has been loaded in 8-bit if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: raise ValueError( From ca07f5a3e56ea1af6d695a13dbb4fa45fbbff750 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 09:47:31 +0000 Subject: [PATCH 72/73] test_hqq fixes --- tests/quantization/hqq/test_hqq.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 89323be3f0d6..e4e01f864963 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -22,6 +22,7 @@ require_torch_gpu, require_torch_multi_gpu, slow, + torch_device, ) from transformers.utils import is_hqq_available, is_torch_available @@ -108,18 +109,10 @@ def test_fp16_quantized_model(self): """ Simple LLM model testing fp16 """ - compute_dtype = torch.float16 - torch_device = "cuda:0" - cache_dir = None - quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) hqq_runner = HQQLLMRunner( - model_id=model_id, - quant_config=quant_config, - compute_dtype=compute_dtype, - device=torch_device, - cache_dir=cache_dir, + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device ) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) @@ -129,10 +122,6 @@ def test_bfp16_quantized_model_with_offloading(self): """ Simple LLM model testing bfp16 with meta-data offloading """ - compute_dtype = torch.bfloat16 - torch_device = "cuda:0" - cache_dir = None - q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False} q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True} quant_config = 
HqqConfig( @@ -148,11 +137,7 @@ def test_bfp16_quantized_model_with_offloading(self): ) hqq_runner = HQQLLMRunner( - model_id=model_id, - quant_config=quant_config, - compute_dtype=compute_dtype, - device=torch_device, - cache_dir=cache_dir, + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device ) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) @@ -171,17 +156,11 @@ def test_fp16_quantized_model_multipgpu(self): """ Simple LLM model testing fp16 with multi-gpu """ - compute_dtype = torch.float16 - cache_dir = None quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) hqq_runner = HQQLLMRunner( - model_id=model_id, - quant_config=quant_config, - compute_dtype=compute_dtype, - device="auto", - cache_dir=cache_dir, + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto" ) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) From 3d777ed7fe677228ec4b67c2aaa1c90ff46a57a0 Mon Sep 17 00:00:00 2001 From: mobicham Date: Thu, 2 May 2024 10:00:41 +0000 Subject: [PATCH 73/73] make fix-copies --- src/transformers/utils/quantization_config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index eacf7867106b..f9e503cf862f 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -206,8 +206,6 @@ class HqqConfig(QuantizationConfigMixin): If set, each layer specified by its id will use its dedicated quantization configuration. skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): List of `nn.Linear` layers to skip. - show_progress ('bool', *optional*, defaults to `True`): - Show tqdm quantization progress for each shard. kwargs (`Dict[str, Any]`, *optional*): Additional parameters from which to initialize the configuration object. """
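As a quick illustration of the configuration behaviour introduced in this series (per-layer `dynamic_config` support and the `axis` validation added in PATCH 66), a small sketch is shown below. It assumes the `hqq` package is installed so that `HqqConfig` can be constructed.

```python
from transformers import HqqConfig

# Per-layer ("dynamic") configuration, as shown in the updated quantization.md
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
quant_config = HqqConfig(dynamic_config={
    "self_attn.q_proj": q4_config,
    "mlp.down_proj": q4_config,
})
print(list(quant_config.quant_config.keys()))  # ['self_attn.q_proj', 'mlp.down_proj']

# The axis check added in PATCH 66 rejects anything other than 0 or 1 up front
try:
    HqqConfig(nbits=4, group_size=64, axis=2)
except ValueError as err:
    print(err)  # Invalid axis value. Only 0 and 1 are allowed.
```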