Commit
Add converters and transformations for converting an AutoGPTQ checkpoint to be loadable in SparseML; credits to @dbogonwicz for a dimension mismatch bugfix
1 parent 38a1214, commit 643598e
Showing 2 changed files with 355 additions and 0 deletions.
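
In brief, the intended workflow is: run the converter over an AutoGPTQ safetensors checkpoint, then load the converted directory with SparseML. A minimal sketch (the checkpoint path is illustrative, and since this diff does not show the first file's path, the converter import below is an assumption):

# assumed import location; the first file's path is not shown in this diff
from sparseml.utils.pytorch.converters import ExllamaToCompressedTensorConverter
from sparseml.transformers import SparseAutoModelForCausalLM

# rewrite the AutoGPTQ checkpoint into a SparseML-loadable layout
save_dir = ExllamaToCompressedTensorConverter.convert_from_safetensors(
    "/path/to/autogptq_checkpoint", save_dir="converted_model"
)
model = SparseAutoModelForCausalLM.from_pretrained(save_dir)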
@@ -0,0 +1,123 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import shutil
from abc import ABC
from pathlib import Path
from typing import Dict, Union

import torch
from safetensors.torch import save_file

from sparseml.pytorch.model_load.helpers import load_safetensors_state_dict
from sparseml.utils.pytorch.converters.transfomations import autogptq_transformations


StateDictType = Union[Dict[str, torch.Tensor], str, Path]

class BaseConverter(ABC):
    @classmethod
    def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
        # apply each of the converter's transformations, in order,
        # to a shallow copy of the state_dict
        new_state_dict = copy.copy(state_dict)
        for transformation in cls.transformations():
            new_state_dict = transformation(new_state_dict)
        return new_state_dict

    @classmethod
    def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
        _validate_safetensors_file_path(filepath)

        filepath_: Path = Path(filepath)
        if not save_dir:
            save_dir = "compressed_tensors_model"

        save_dir_: Path = Path(save_dir)
        save_dir_.mkdir(parents=True, exist_ok=True)

        # transform and save the state_dict
        if filepath_.is_dir():
            for file in filepath_.glob("*.safetensors"):
                state_dict: StateDictType = load_safetensors_state_dict(file)
                new_state_dict = cls.translate(state_dict=state_dict)
                save_file(new_state_dict, save_dir_ / file.name)
            _move_non_safetensor_files_(filepath_, save_dir_)

        elif filepath_.is_file():
            state_dict: StateDictType = load_safetensors_state_dict(filepath)
            new_state_dict = cls.translate(state_dict=state_dict)
            save_file(new_state_dict, save_dir_ / filepath_.name)

        return str(save_dir_)

    @classmethod
    def transformations(cls):
        """
        Returns the transformations that are applied in the converter
        """
        raise NotImplementedError()


class ExllamaToCompressedTensorConverter(BaseConverter):
    @classmethod
    def transformations(cls):
        return autogptq_transformations()


def _validate_safetensors_file_path(filepath: str):
    """
    Given a file path, it is valid if:
        - The file exists
        - The file is either a single .safetensors file or a
          directory containing .safetensors files

    :param filepath: A string file path to validate
    """
    filepath_: Path = Path(filepath)

    if not filepath_.exists():
        raise FileNotFoundError(f"File not found: {filepath}")

    if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
        raise FileNotFoundError(
            f"No .safetensors files found in directory: {filepath}"
        )

    if filepath_.is_file() and filepath_.suffix != ".safetensors":
        raise ValueError(f"File must be a .safetensors file: {filepath}")


def _move_non_safetensor_files_(source_dir: Path, dest_dir: Path):
    # move auxiliary checkpoint files (config.json, tokenizer files, etc.)
    # over to the destination directory alongside the converted tensors
    for file in source_dir.glob("*"):
        if file.suffix != ".safetensors":
            shutil.move(file, dest_dir / file.name)


def local_test():
    autogptq_model_path: str = "/network/rahul/tinyllama_1b_test_w4a16"
    new_path = ExllamaToCompressedTensorConverter.convert_from_safetensors(
        autogptq_model_path, save_dir="local/models/compressed_tensor_equi"
    )

    from sparseml.transformers import SparseAutoModelForCausalLM

    model = SparseAutoModelForCausalLM.from_pretrained(new_path)

    # run eval on this model
    # results should be the same as for the autogptq model


if __name__ == "__main__":
    local_test()
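
BaseConverter is a small template-method class: translate() folds the transformations supplied by a subclass's transformations() over the state_dict, in order. A minimal sketch of the pattern (UpcastConverter and _upcast_fp16 are illustrative names, not part of this commit; it assumes BaseConverter from the file above is importable):

import torch


def _upcast_fp16(state_dict):
    # illustrative transformation: promote fp16 tensors to fp32
    return {
        key: tensor.float() if tensor.dtype == torch.float16 else tensor
        for key, tensor in state_dict.items()
    }


class UpcastConverter(BaseConverter):
    @classmethod
    def transformations(cls):
        return (_upcast_fp16,)


converted = UpcastConverter.translate({"w": torch.ones(2, dtype=torch.float16)})
assert converted["w"].dtype == torch.float32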
src/sparseml/utils/pytorch/converters/transfomations.py
232 changes: 232 additions & 0 deletions

@@ -0,0 +1,232 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
from typing import Callable, Dict

import numpy as np
import torch
from torch import Tensor


_LOGGER = logging.getLogger(__name__)

TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]


def _log_call(func):
    # decorator that logs, at debug level, when a transformation
    # starts and finishes
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        _LOGGER.debug("Applying transformation: %s", func.__name__.upper())
        return_value = func(*args, **kwargs)
        _LOGGER.debug("Transformation: %s complete", func.__name__.upper())
        return return_value

    return wrapper


def is_gptq_quantization_target(key: str) -> bool:
    """
    Assumes self_attn and mlp are the only quantization targets
    in model layers of the state_dict.

    :param key: The key of the state_dict
    :return: True if the key is a quantization target, False otherwise
    """
    return "model.layers" in key and ("self_attn" in key or "mlp" in key)


def transform_to_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Transforms the state_dict keys to match the exllama format
    The renames include:
        - scales -> weight_fake_quant.scale
        - qzeros -> weight_fake_quant.zero_point
        - qweight -> weight

    Note: does not transform the actual tensor values

    :pre-condition: The state_dict should be for a quantized model
    :pre-condition: Targets only the weights of the self_attn and mlp nodes
    :param state_dict: The quantized state_dict to be transformed
    :return: The transformed state_dict
    """
    name_map: Dict[str, str] = {
        ".scales": ".weight_fake_quant.scale",
        ".qzeros": ".weight_fake_quant.zero_point",
        ".qweight": ".weight",
    }

    updated_state_dict = {}
    for key, tensor in state_dict.items():
        # rename keys that end with a mapped suffix; copy all others unchanged
        if any(key.endswith(target_suffix := suffix) for suffix in name_map):
            updated_key = key.replace(target_suffix, name_map[target_suffix])
            updated_state_dict[updated_key] = tensor
        else:
            updated_state_dict[key] = tensor
    return updated_state_dict


def transform_gptq_weights_and_reshape_tensors(
    state_dict: Dict[str, Tensor]
) -> Dict[str, Tensor]:
    """
    Transforms weights into the shapes and types required by the Exllama format
    The transformations include:
        - Unpack each 32-bit qweight tensor into its eight 4-bit components
          and dequantize them to fp32 using the scales, zeros, and g_idx
          tensors
        - Reshape the scales tensor to [1, x] and convert to fp16
        - Reshape the zero points tensor to [1, x] of type int32 and fill with
          zeros (it is assumed that quantization was symmetric)

    :pre-condition: The state_dict should be for a quantized model
    :pre-condition: The state_dict should have been transformed to exllama names
    :pre-condition: The state_dict should have the bias and g_idx tensors added
    :param state_dict: The state_dict to be transformed
    :return: The transformed state_dict, with dequantized and reshaped tensors
    """
    transformed_state_dict: Dict[str, Tensor] = {}

    # auxiliary dict to store transformed weights
    transformed_weights_dict: Dict[str, Tensor] = {}

    # dequantize the qweights before the scales and qzeros
    # because the ordering in which tensors are fetched
    # is not guaranteed by our implementation
    for key, tensor in state_dict.items():
        if is_gptq_quantization_target(key) and key.endswith(".qweight"):
            # dequantize the packed weight tensor
            scales = state_dict[key.replace("qweight", "scales")]
            qzeros = state_dict[key.replace("qweight", "qzeros")]
            g_idx = state_dict[key.replace("qweight", "g_idx")]

            zeros = unpack_zeros(qzeros)
            weight = unpack_int32_into_fp32(
                qweight=tensor,
                scales=scales,
                zeros=zeros,
                g_idx=g_idx,
            )
            assert weight.dtype == torch.float32
            transformed_weights_dict[key] = weight

    # transform scales and zero points
    for key, tensor in state_dict.items():
        if is_gptq_quantization_target(key) and key.endswith(".scales"):
            # scales [x] should be reshaped to [1, x]
            # and converted to fp16
            scales = tensor.reshape(1, -1).half()
            transformed_state_dict[key] = scales
        elif is_gptq_quantization_target(key) and key.endswith(".qzeros"):
            # zero points [8x] should be reshaped to [1, x]
            # of type int32 and filled with zeros (symmetric quantization)
            zeros = torch.zeros(tensor.shape[0] // 8, dtype=torch.int32)
            transformed_state_dict[key] = zeros.reshape(1, -1)
        else:
            transformed_state_dict[key] = tensor

    # overwrite the old packed weights with the new dequantized weights
    transformed_state_dict.update(transformed_weights_dict)

    # auxiliary weights_dict not needed anymore
    del transformed_weights_dict

    return transformed_state_dict


def unpack_zeros(qzeros: Tensor) -> Tensor:
    """
    Unpack the packed zero-point tensor: each 32-bit integer holds
    eight 4-bit zero points

    :param qzeros: The packed zero points tensor of shape [x, y]
    :return: The unpacked zero points tensor of int32 dtype and shape [x, 8y]
    """
    bits = 4
    qzeros = qzeros.numpy().astype(np.uint32)
    intzeros = np.zeros(
        (qzeros.shape[0], qzeros.shape[1] * 32 // bits), dtype=np.uint32
    )

    i = 0
    col = 0
    while col < intzeros.shape[1]:
        if bits in [4]:
            # shift out the eight 4-bit values packed into this column
            for j in range(i, min(i + (32 // bits), intzeros.shape[1])):
                intzeros[:, j] = (qzeros[:, col] >> (bits * (j - i))) & 0xF
            i += 32 // bits
            col += 1
        else:
            raise NotImplementedError("Only 4 bits are supported.")

    intzeros = intzeros.astype(np.int32)
    intzeros = torch.from_numpy(intzeros)

    return intzeros


def unpack_int32_into_fp32(
    qweight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor
) -> Tensor:
    """
    Unpack the quantized weight tensor from 32 bit integers into 4 bit integers,
    and then dequantize them using the scales, zeros, and g_idx tensors.

    :param qweight: The quantized weight tensor of int32 dtype and shape [x, y]
    :param scales: The scales tensor
    :param zeros: The zero points tensor
    :param g_idx: The group index tensor
    :return: The dequantized weight tensor of fp32 dtype and shape [y, 8x]
    """
    bits = 4
    qweight = qweight.numpy().astype(np.uint32)
    intweight = np.zeros(
        (qweight.shape[0] * 32 // bits, qweight.shape[1]), dtype=np.uint32
    )

    i = 0
    row = 0
    while row < intweight.shape[0]:
        if bits in [4]:
            # shift out the eight 4-bit values packed into this row
            for j in range(i, min(i + (32 // bits), intweight.shape[0])):
                intweight[j] = (qweight[row] >> (bits * (j - i))) & 0xF
            i += 32 // bits
            row += 1
        else:
            raise NotImplementedError("Only 4 bits are supported.")

    intweight = torch.from_numpy(intweight.astype(np.int32))
    intweight = intweight.t().contiguous()

    scales = scales.t().contiguous()
    zeros = zeros.t().contiguous()
    scale_zeros = zeros * scales
    scales = scales.clone().half()

    # dequantize column by column: w = q * scale[group] - zero[group] * scale[group]
    weight = []
    infeatures = intweight.shape[1]
    for idx in range(infeatures):
        weight.append(
            (
                intweight[:, idx].float() * scales[:, g_idx[idx]]
                - scale_zeros[:, g_idx[idx]]
            )[:, None]
        )
    weight = torch.cat(weight, dim=1)

    return weight


def autogptq_transformations():
    """
    :return: the transformations required to convert and run an AutoGPTQ checkpoint
    """
    return (
        transform_gptq_weights_and_reshape_tensors,
        transform_to_exllama_names,
    )
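
To sanity-check the bit-twiddling, a minimal sketch (not part of the commit) that packs eight 4-bit values into a single 32-bit integer in the least-significant-nibble-first layout these functions decode, then recovers them with unpack_zeros:

import torch

from sparseml.utils.pytorch.converters.transfomations import unpack_zeros

# pack eight 4-bit values into one 32-bit integer, lowest nibble first
values = [3, 1, 4, 1, 5, 9, 2, 6]
packed = 0
for i, v in enumerate(values):
    packed |= (v & 0xF) << (4 * i)

qzeros = torch.tensor([[packed]], dtype=torch.int32)
unpacked = unpack_zeros(qzeros)  # shape [1, 8], dtype int32
assert unpacked.tolist() == [values]

Note that autogptq_transformations orders the weight transformation before the renames: transform_gptq_weights_and_reshape_tensors looks tensors up by their original .qweight/.scales/.qzeros suffixes, so it must run before transform_to_exllama_names rewrites those keys.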