Add converters and transformations for converting an autogptq checkpoint to be loadable in sparseml; credits to @dbogonwicz for a dimension mismatch bugfix

rahul-tuli committed Jun 5, 2024
1 parent 38a1214 commit 643598e
Showing 2 changed files with 355 additions and 0 deletions.
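
As a rough usage sketch of what this commit enables (the paths below are placeholders, and the module path is assumed to follow the new file's location), the converter is intended to be driven roughly like this:

from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.utils.pytorch.converters.converters import (
    ExllamaToCompressedTensorConverter,
)

# convert an autogptq checkpoint (a .safetensors file or a directory of them)
# into a layout sparseml can load; both paths are placeholders
converted_path = ExllamaToCompressedTensorConverter.convert_from_safetensors(
    "path/to/autogptq_checkpoint", save_dir="path/to/converted_model"
)

# load the converted checkpoint with sparseml
model = SparseAutoModelForCausalLM.from_pretrained(converted_path)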
123 changes: 123 additions & 0 deletions src/sparseml/utils/pytorch/converters/converters.py
@@ -0,0 +1,123 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import shutil
from abc import ABC
from pathlib import Path
from typing import Dict, Union

import torch

from safetensors.torch import save_file
from sparseml.pytorch.model_load.helpers import load_safetensors_state_dict
from sparseml.utils.pytorch.converters.transfomations import autogptq_transformations


StateDictType = Union[Dict[str, torch.Tensor], str, Path]


class BaseConverter(ABC):
@classmethod
    def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
        """
        Apply the converter's transformations, in order, to a copy of the state_dict
        """
        new_state_dict = copy.copy(state_dict)
for transformation in cls.transformations():
new_state_dict = transformation(new_state_dict)
return new_state_dict

@classmethod
    def convert_from_safetensors(cls, filepath: str, save_dir: str = None):
        """
        Convert a .safetensors file or a directory of .safetensors files,
        applying the converter's transformations and saving the results;
        non-safetensors files in a source directory are moved alongside
        the converted files

        :param filepath: path to a .safetensors file or a directory containing them
        :param save_dir: directory to save the converted files to, defaults to
            "compressed_tensors_model"
        :return: the path to the directory the converted files were saved to
        """
        _validate_safetensors_file_path(filepath)

filepath_: Path = Path(filepath)
if not save_dir:
save_dir = "compressed_tensors_model"

save_dir_: Path = Path(save_dir)
save_dir_.mkdir(exist_ok=True)

# transform and save the state_dict

if filepath_.is_dir():
for file in filepath_.glob("*.safetensors"):
state_dict: StateDictType = load_safetensors_state_dict(file)
new_state_dict = cls.translate(state_dict=state_dict)
                save_file(new_state_dict, filename=str(save_dir_ / file.name))
_move_non_safetensor_files_(filepath_, save_dir_)

elif filepath_.is_file():
state_dict: StateDictType = load_safetensors_state_dict(filepath)
new_state_dict = cls.translate(state_dict=state_dict)
            save_file(new_state_dict, filename=str(save_dir_ / filepath_.name))

return str(save_dir_)

@classmethod
def transformations(cls):
"""
Returns the list of transformations that are applied in the converter
"""
raise NotImplementedError()


class ExllamaToCompressedTensorConverter(BaseConverter):
@classmethod
def transformations(cls):
return autogptq_transformations()


def _validate_safetensors_file_path(filepath: str):
"""
Given a file path, it is valid if:
- The file exists
- The file is either a single .safetensors file or a
directory containing .safetensors files
:param filepath: A string file path to validate
"""

filepath_: Path = Path(filepath)

if not filepath_.exists():
raise FileNotFoundError(f"File not found: {filepath}")

if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}")

if filepath_.is_file() and not filepath_.suffix == ".safetensors":
raise ValueError(f"File must be a .safetensors file: {filepath}")


def _move_non_safetensor_files_(source_dir: Path, dest_dir: Path):
for file in source_dir.glob("*"):
if file.suffix != ".safetensors":
shutil.move(file, dest_dir / file.name)


def local_test():
autogptq_model_path: str = "/network/rahul/tinyllama_1b_test_w4a16"
new_path = ExllamaToCompressedTensorConverter.convert_from_safetensors(
autogptq_model_path, save_dir="local/models/compressed_tensor_equi"
)

from sparseml.transformers import SparseAutoModelForCausalLM

model = SparseAutoModelForCausalLM.from_pretrained(new_path)

# run eval on this model
# results should be same as the autogptq model


if __name__ == "__main__":
    local_test()
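
BaseConverter above only asks subclasses to supply transformations(); translate() then applies those callables in order to a copy of the state_dict. A minimal, hypothetical sketch of a new converter built on that pattern (upper_case_keys and UpperCaseConverter are illustrative names, not part of this commit, and the import path is assumed from the file location):

from typing import Dict

from torch import Tensor

from sparseml.utils.pytorch.converters.converters import BaseConverter


def upper_case_keys(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # illustrative transformation: upper-case every key, leave tensors untouched
    return {key.upper(): tensor for key, tensor in state_dict.items()}


class UpperCaseConverter(BaseConverter):
    @classmethod
    def transformations(cls):
        # callables applied in order by BaseConverter.translate
        return (upper_case_keys,)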
232 changes: 232 additions & 0 deletions src/sparseml/utils/pytorch/converters/transfomations.py
@@ -0,0 +1,232 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import logging
from typing import Callable, Dict

import numpy
import numpy as np
import torch
from torch import Tensor


_LOGGER = logging.getLogger(__name__)

TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]


def _log_call(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_LOGGER.debug("Applying transformation: %s", func.__name__.upper())
return_value = func(*args, **kwargs)
_LOGGER.debug("Transformation: %s complete", func.__name__.upper())
return return_value

return wrapper


def is_gptq_quantization_target(key: str) -> bool:
"""
Assumes self_attn and mlp are the only quantization targets
in model layers of the state_dict.
:param key: The key of the state_dict
:return: True if the key is a quantization target, False otherwise
"""
return "model.layers" in key and ("self_attn" in key or "mlp" in key)


def transform_to_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Transforms the state_dict keys to match with exllama format
The renames include:
- scales -> weight_fake_quant.scale
- qzeros -> weight_fake_quant.zero_point
- qweight -> weight
    Note: does not transform the actual tensor values
:pre-condition: The state_dict should be for a quantized model
:pre-condition: Targets only the weights of the self_attn and mlp nodes
:param state_dict: The quantized state_dict to be transformed
:return: The transformed state_dict
"""

name_map: Dict[str, str] = {
".scales": ".weight_fake_quant.scale",
".qzeros": ".weight_fake_quant.zero_point",
".qweight": ".weight",
}

updated_state_dict = {}
for key, tensor in state_dict.items():
if any(key.endswith(target_suffix := suffix) for suffix in name_map):
updated_key = key.replace(target_suffix, name_map[target_suffix])
updated_state_dict[updated_key] = tensor
else:
updated_state_dict[key] = tensor
return updated_state_dict


def transform_gptq_weights_and_reshape_tensors(
state_dict: Dict[str, Tensor]
) -> Dict[str, Tensor]:
"""
    Transforms weights into their required shapes and types for Exllama format
    The transformations include:
        - Dequantize the packed int32 weight tensor using the scales, zeros,
          and g_idx tensors (each int32 packs eight 4 bit values); the tensor
          keeps its qweight key until the name transformation renames it
        - Reshape the scales tensor to [1, x] and convert to fp16
        - Reshape the zero points tensor to [1, x] of type int32 and fill with zeros
          (it is assumed that quantization was symmetric)
    :pre-condition: The state_dict should be for a quantized model
    :pre-condition: The state_dict should still use the original autogptq names
        (transform_to_exllama_names is applied after this transformation)
    :pre-condition: The state_dict should have the bias and g_idx tensors added
:param state_dict: The state_dict to be transformed
:return: The transformed state_dict, with repacked and reshaped tensors
"""

transformed_state_dict: Dict[str, Tensor] = {}

    # auxiliary dict to store transformed weights
    transformed_weights_dict: Dict[str, Tensor] = {}

    # dequantize qweights before processing scales and qzeros,
    # because the ordering in which tensors are fetched
    # is not guaranteed by our implementation
for key, tensor in state_dict.items():
if is_gptq_quantization_target(key) and key.endswith(".qweight"):
            # dequantize the packed weight tensor using its scales, zeros, and g_idx
            scales = state_dict[key.replace("qweight", "scales")]
            qzeros = state_dict[key.replace("qweight", "qzeros")]
            g_idx = state_dict[key.replace("qweight", "g_idx")]

            zeros = unpack_zeros(qzeros)
            new_weight = unpack_int32_into_fp32(
                qweight=tensor,
                scales=scales,
                zeros=zeros,
                g_idx=g_idx,
            )
            # unpack_int32_into_fp32 returns a dequantized fp32 tensor
            assert new_weight.dtype == torch.float32
            transformed_weights_dict[key] = new_weight

# transform scales and zero points
for key, tensor in state_dict.items():
if is_gptq_quantization_target(key) and key.endswith(".scales"):
# scales [x] should be reshaped to [1, x]
# and converted to fp16
scales = tensor.reshape(1, -1).half()
transformed_state_dict[key] = scales
elif is_gptq_quantization_target(key) and key.endswith(".qzeros"):
# zero points [8x] should be reshaped to [1, x]
# of type int32 and filled with zeros (symmetric quantization)
zeros = torch.zeros(tensor.shape[0] // 8, dtype=torch.int32)
transformed_state_dict[key] = zeros.reshape(1, -1)
else:
transformed_state_dict[key] = tensor

# overwrite old weights with the new quantized weights
transformed_state_dict.update(transformed_weights_dict)

    # auxiliary weights_dict not needed anymore
del transformed_weights_dict

return transformed_state_dict


def unpack_zeros(qzeros):
    """
    Unpack the packed int32 zero point tensor into one int32 value
    per 4 bit zero point
    """
    bits = 4
qzeros = qzeros.numpy().astype(np.uint32)
intzeros = np.zeros((qzeros.shape[0], qzeros.shape[1] * 32 // bits), dtype=np.uint32)

i = 0
col = 0
while col < intzeros.shape[1]:
if bits in [4]:
for j in range(i, min(i + (32 // bits), intzeros.shape[1])):
intzeros[:, j] = (qzeros[:, col] >> (bits * (j - i))) & 0xF
i += 32 // bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")

intzeros = intzeros.astype(np.int32)
intzeros = torch.from_numpy(intzeros)

return intzeros


def unpack_int32_into_fp32(
qweight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor
) -> Tensor:
"""
Unpack the quantized weight tensor from 32 bit integers into 4 bit integers,
and then dequantize them using the scales, zeros, and g_idx tensors.
:param qweight: The quantized weight tensor of int32 dtype and shape [x, y]
:param scales: The scales tensor
:param zeros: The zero points tensor
:param g_idx: The group index tensor
    :return: The dequantized weight tensor of shape [y, 8x]
"""
bits = 4
qweight = qweight.numpy().astype(numpy.uint32)
intweight = numpy.zeros(
(qweight.shape[0] * 32 // bits, qweight.shape[1]), dtype=numpy.uint32
)

i = 0
row = 0
while row < intweight.shape[0]:
if bits in [4]:
for j in range(i, min(i + (32 // bits), intweight.shape[0])):
intweight[j] = (qweight[row] >> (bits * (j - i))) & 0xF
i += 32 // bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")

intweight = torch.from_numpy(intweight.astype(numpy.int32))
intweight = intweight.t().contiguous()

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
scales = scales.clone().half()

weight = []
infeatures = intweight.shape[1]
for idx in range(infeatures):
weight.append(
(intweight[:, idx].float() * scales[:, g_idx[idx]] - scale_zeros[:, g_idx[idx]])[
:, None
]
)
weight = torch.cat(weight, dim=1)

return weight


def autogptq_transformations():
"""
    :return: the transformations required to convert and run an autogptq checkpoint
"""
return (
transform_gptq_weights_and_reshape_tensors,
transform_to_exllama_names
)
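
The unpacking loops in unpack_zeros and unpack_int32_into_fp32 rely on each int32 holding eight 4 bit values, lowest nibble first, recovered with the same shift-and-mask used above. A small, self-contained check of that arithmetic (the example values are arbitrary):

# pack eight 4 bit values (0..15) into a single integer, lowest nibble first
nibbles = [3, 7, 0, 15, 1, 2, 9, 4]
packed = 0
for j, value in enumerate(nibbles):
    packed |= value << (4 * j)

# recover them with the shift-and-mask from unpack_zeros: (x >> (4 * j)) & 0xF
unpacked = [(packed >> (4 * j)) & 0xF for j in range(8)]
assert unpacked == nibbles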
