LLM: 2bit quantization initial support (#10042)
* basic quantize support

* fix new module name

* small update

* add mixed int4 with iq2_xxs

* remove print

* code refactor

* fix style

* address code review comments
rnwang04 authored Feb 6, 2024
1 parent 81acd6f commit 96c5d4d
Showing 6 changed files with 164 additions and 22 deletions.
24 changes: 24 additions & 0 deletions python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
@@ -965,6 +965,30 @@ def ggml_quantize_tensor(
_lib.ggml_quantize_tensor.restype = ctypes.c_size_t


def ggml_quantize_tensor_with_weights(
    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
    dst: ctypes.c_void_p,
    qtype: ctypes.c_int,
    nrow: ctypes.c_int,
    n_per_row: ctypes.c_int,
    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
) -> int:
    return _lib.ggml_quantize_tensor_with_weights(src, dst, qtype, nrow, n_per_row, hist, weights)


_lib.ggml_quantize_tensor_with_weights.argtypes = [
    ctypes.POINTER(ctypes.c_float),
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_int64),
    ctypes.POINTER(ctypes.c_float),
]
_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t


def ggml_type_size(qtype: ctypes.c_int) -> int:
return _lib.ggml_type_size(qtype)

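The new binding follows the same ctypes pattern as `ggml_quantize_tensor` but additionally takes an importance-weight pointer plus explicit `nrow`/`n_per_row`. A minimal sketch of driving it directly is shown below; the helper name `quantize_2bit_with_imatrix` and the import alias are illustrative assumptions, and the buffer sizing mirrors `ggml_convert_qtype` in `low_bit_linear.py` later in this commit.

```python
import ctypes

import torch

from bigdl.llm.ggml.model.llama import llama_cpp as ggml   # assumed import alias
from bigdl.llm.ggml.quantize import ggml_tensor_qtype


def quantize_2bit_with_imatrix(weight, imatrix, qtype_name="iq2_xxs"):
    # weight: fp32 tensor of shape [out_features, in_features]
    # imatrix: fp32 tensor with one importance value per input column (in_features)
    qtype = ggml_tensor_qtype[qtype_name]
    nrow, n_per_row = weight.shape
    w32 = weight.float().contiguous()      # keep references so the raw pointers stay valid
    im32 = imatrix.float().contiguous()
    dst_size = (w32.numel() // ggml.ggml_qk_size(qtype)) * ggml.ggml_type_size(qtype)
    dst_tensor = torch.empty(dst_size, dtype=torch.uint8)
    hist = (ctypes.c_int64 * 16)()
    src = ctypes.cast(w32.data_ptr(), ctypes.POINTER(ctypes.c_float))
    wts = ctypes.cast(im32.data_ptr(), ctypes.POINTER(ctypes.c_float))
    ggml.ggml_quantize_tensor_with_weights(src, ctypes.c_void_p(dst_tensor.data_ptr()),
                                           qtype, nrow, n_per_row, hist, wts)
    return dst_tensor
```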
4 changes: 3 additions & 1 deletion python/llm/src/bigdl/llm/ggml/quantize.py
@@ -39,7 +39,9 @@
"mixed_fp8": 18, # Mixture of Formats Quantization 8 bits
"fp8_e5m2": 19, # fp8 in e5m2 format
"fp8": 19, # fp8 in e5m2 format
"bf16": 20}
"bf16": 20,
"iq2_xxs": 21,
"iq2_xs": 22}

_llama_quantize_type = {"q4_0": 2,
"q4_1": 3,
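A quick sanity check of the mapping just added (the values are exactly those in the dict above):

```python
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

assert ggml_tensor_qtype["iq2_xxs"] == 21   # new 2-bit type
assert ggml_tensor_qtype["iq2_xs"] == 22    # new 2-bit type
```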
20 changes: 14 additions & 6 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -43,7 +43,7 @@
import transformers
import importlib.util
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger
from .utils import logger, get_cur_qtype_and_imatrix
from typing import Union
import numpy as np
import os
@@ -190,7 +190,8 @@ def convert_gptq(module, awq=False, llm_awq=False):

def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False,
cpu_embedding=False, prefix_name=''):
cpu_embedding=False, prefix_name='',
imatrix_data=None):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from bigdl.llm.transformers.embedding import LLMEmbedding
@@ -248,15 +249,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None,
mp_group=mp_group,
)

cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data)
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
requires_grad=False,
quantized=False,
_shape=None,
convert_shape_only=convert_shape_only,
qtype=qtype).to(device)
qtype=cur_qtype,
imatrix=cur_imatrix,
in_features=in_features).to(device)
new_linear._parameters['weight'] = paramsLowBit
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -328,7 +333,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name,
convert_shape_only,
cpu_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
imatrix_data=imatrix_data
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@@ -505,7 +511,8 @@ def _optimize_pre(model):
def ggml_convert_low_bit(model, qtype, optimize_model=True,
convert_shape_only=False, device="cpu",
modules_to_not_convert=None, cpu_embedding=False,
lightweight_bmm=False, torch_dtype="auto"):
lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None):
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
f"format......")
@@ -517,6 +524,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
None, convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
)
if not has_been_replaced:
warnings.warn(
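With these changes, `imatrix_data` is threaded from `ggml_convert_low_bit` down to each replaced Linear module. A minimal sketch of calling the conversion entry point directly is shown below (normally `from_pretrained` handles this); it assumes `model` is an fp32 Hugging Face model already in memory, and the `.imatrix` path is a placeholder.

```python
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import ggml_convert_low_bit
from bigdl.llm.transformers.utils import load_imatrix_data

imatrix_data = load_imatrix_data("llama-2-7b.imatrix")   # placeholder path
model = ggml_convert_low_bit(model, ggml_tensor_qtype["iq2_xxs"],
                             optimize_model=True,
                             imatrix_data=imatrix_data)
```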
33 changes: 25 additions & 8 deletions python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -70,6 +70,8 @@
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
IQ2_XS = ggml_tensor_qtype["iq2_xs"]


def get_block_size(qtype: str):
@@ -81,20 +83,20 @@ def get_qk_size(qtype: int):


def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
device=None, convert_shape_only=False):
device=None, convert_shape_only=False,
imatrix: torch.Tensor=None,
in_features: int=None):
QK = ggml.ggml_qk_size(qtype)
block_size_in_bytes = ggml.ggml_type_size(qtype)

invalidInputError(tensor.dtype == torch.float,
"Input tensor must be float32")
src = tensor.data.data_ptr()
src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
n = tensor.numel()
invalidInputError(n % QK == 0,
"Input tensor size must be multiple of 64")
n = tensor.numel() # all elements
k = tensor.shape[-1]
invalidInputError(k % QK == 0,
"Last dim of input tensor must be multiple of 64")
f"Last dim of input tensor must be multiple of {QK}")

dst_size = (n // QK) * block_size_in_bytes
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
@@ -103,7 +105,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
if not convert_shape_only and device != 'meta':
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
if qtype not in [IQ2_XXS, IQ2_XS]:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
else:
# quantize with importance matrix
imatrix = imatrix.data.data_ptr()
imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
# pass nrow and n_per_row
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
n // in_features, in_features,
hist, imatrix)
return dst_tensor


@@ -193,7 +204,9 @@ def __new__(cls,
quantized=False,
_shape=None,
convert_shape_only=False,
qtype=None):
qtype=None,
imatrix=None,
in_features=None):
if data is None:
data = torch.empty(0)

@@ -203,6 +216,8 @@ def __new__(cls,
self._shape = _shape
self.qtype = qtype
self.convert_shape_only = convert_shape_only
self.imatrix = imatrix
self.in_features = in_features
return self

def ggml_mse(self, w, ggml_qtype, device):
@@ -255,7 +270,9 @@ def quantize(self, device=None):
else:
w_quantized = ggml_convert_qtype(w, self.qtype,
device=device,
convert_shape_only=self.convert_shape_only)
convert_shape_only=self.convert_shape_only,
imatrix=self.imatrix,
in_features=self.in_features)
self.data = w_quantized
self.quantized = True
self._shape = w.shape
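To make the extended `FP4Params` arguments concrete, here is a sketch of the per-layer call that `convert.py` now issues; the shapes and imatrix values are illustrative, and in the real path `cur_qtype`/`cur_imatrix` come from `get_cur_qtype_and_imatrix` rather than being hard-coded.

```python
import torch

from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.low_bit_linear import FP4Params

weight = torch.randn(4096, 11008)    # fp32 weight of a down_proj-like Linear layer
imatrix = torch.rand(11008)          # one importance value per input column

params = FP4Params(data=weight,
                   requires_grad=False,
                   quantized=False,
                   _shape=None,
                   convert_shape_only=False,
                   qtype=ggml_tensor_qtype["iq2_xxs"],
                   imatrix=imatrix,
                   in_features=11008)
# In the conversion path, moving the params to the target device performs the
# quantization, which now routes to ggml_quantize_tensor_with_weights for iq2_* types.
params = params.to("cpu")
```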
26 changes: 19 additions & 7 deletions python/llm/src/bigdl/llm/transformers/model.py
@@ -41,7 +41,7 @@
from transformers.configuration_utils import PretrainedConfig
from .utils import extract_local_archive_file, \
load_state_dict, \
get_local_shard_files
get_local_shard_files, load_imatrix_data
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.gguf.api import load_gguf_model
@@ -107,10 +107,10 @@ def from_pretrained(cls,
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
``'fp16'`` or ``'bf16'``, ``'sym_int4'`` means symmetric int 4,
``'asym_int4'`` means asymmetric int 4, ``'nf4'`` means 4-bit
NormalFloat, etc. Relevant low bit optimizations will be applied
to the model.
``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be ``True``.
:param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
@@ -121,6 +121,9 @@ def from_pretrained(cls,
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
:param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
:param imatrix: str value, representing the filename of an importance matrix pretrained on
specific datasets, for use with the improved quantization methods recently
added to llama.cpp.
:return: a model instance
"""
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
@@ -243,6 +246,12 @@ def from_pretrained(cls,
else:
kwargs["pretraining_tp"] = 1
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
if q_k in ["iq2_xxs", "iq2_xs"]:
imatrix_file = kwargs.pop("imatrix", None)
invalidInputError(imatrix_file is not None,
"For iq2_xxs and iq2_xs quantization, imatrix is needed.")
imatrix_data = load_imatrix_data(imatrix_file)
kwargs['imatrix_data'] = imatrix_data
model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

if speculative:
@@ -285,7 +294,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
invalidInputError(q_k in ggml_tensor_qtype,
f"Unknown load_in_low_bit value: {q_k}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, mixed_fp4 or mixed_fp8.")
f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
f"mixed_fp4 or mixed_fp8.")
qtype = ggml_tensor_qtype[q_k]

# In case it needs a second try,
@@ -299,6 +309,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
cpu_embedding = True
lightweight_bmm = kwargs.pop("lightweight_bmm", False)
quant_config = kwargs.pop("quantization_config", None)
imatrix_data = kwargs.pop("imatrix_data", None)
_args = copy.deepcopy(args)
_kwargs = copy.deepcopy(kwargs)
awq_config = None
@@ -359,7 +370,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
model = ggml_convert_low_bit(model, qtype, optimize_model,
modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
torch_dtype=kwargs.get("torch_dtype", 'auto'))
torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data)
model.config.update({"bigdl_transformers_low_bit": q_k})

# enable tie_word_embeddings for MPT
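Putting the user-facing pieces together, a hypothetical end-to-end call looks like the following; the model id and imatrix file name are placeholders, and `imatrix` is the new keyword argument validated above.

```python
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # placeholder model id
    load_in_low_bit="iq2_xxs",         # or "iq2_xs"
    imatrix="llama-2-7b.imatrix",      # importance matrix file, e.g. from llama.cpp's imatrix tool
    optimize_model=True,
)
```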
79 changes: 79 additions & 0 deletions python/llm/src/bigdl/llm/transformers/utils.py
@@ -41,11 +41,14 @@
# SOFTWARE.
import os
from transformers.modeling_utils import _add_variant
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from ..utils.common import invalidInputError
from typing import Union
import torch
from torch import nn
import logging
import numpy as np


logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -179,3 +182,79 @@ def get_xpu_device_type(x):
return "pvc"
else:
return "others"


def load_imatrix_data(imatrix_file):
    # this function is adapted from https://github.com/ggerganov/llama.cpp/blob/
    # c82d18e863fcde91b4b1109b1d0c73ea4470c405/examples/quantize/quantize.cpp#L102
    imatrix = open(imatrix_file, 'rb')
    n_entries = imatrix.read(4)
    n_entries = int.from_bytes(n_entries, 'little')
    invalidInputError(n_entries >= 1,
                      f"failed reading name for entry from {imatrix_file}")
    imatrix_data = {}
    for i in range(n_entries):
        cur_len = imatrix.read(4)
        cur_len = int.from_bytes(cur_len, 'little')
        cur_name = str(imatrix.read(cur_len), encoding='utf-8')
        # original cur_name looks like blk.14.attn_output.weight for llama
        # TODO: how to better align and generalize
        name_list = cur_name.split('.')
        layer = name_list[1]
        module_name = name_list[2]
        if 'ffn' in module_name:
            module_name = module_name[4:]  # from ffn_gate to gate
        elif 'attn' in module_name:
            module_name = module_name[5]  # from attn_k to k, attn_output to o
        module_name = layer + '_' + module_name
        ncall = imatrix.read(4)
        ncall = int.from_bytes(ncall, 'little')
        nval = imatrix.read(4)
        nval = int.from_bytes(nval, 'little')
        invalidInputError(nval >= 1,
                          f"failed reading number of values for entry {i}")
        byte_data = imatrix.read(4 * nval)
        idata = np.frombuffer(byte_data, dtype=np.float32)

        if ncall > 0:
            idata = idata / ncall
        imatrix_data[module_name] = torch.from_numpy(idata).float()

    print(f"loaded {len(imatrix_data)} importance matrix entries from {imatrix_file}.")
    return imatrix_data
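The binary layout this parser expects is: a little-endian int32 entry count, then for each entry an int32 name length, the utf-8 name, an int32 ncall, an int32 nval, and nval float32 values. A small synthetic writer, shown only as a testing sketch (not part of the commit), makes the format explicit:

```python
import struct

import numpy as np


def write_imatrix(path, entries):
    """entries maps llama.cpp-style names (e.g. 'blk.0.attn_k.weight')
    to (ncall, float32 array) pairs."""
    with open(path, "wb") as f:
        f.write(struct.pack("<i", len(entries)))
        for name, (ncall, values) in entries.items():
            encoded = name.encode("utf-8")
            f.write(struct.pack("<i", len(encoded)))
            f.write(encoded)
            f.write(struct.pack("<i", ncall))
            f.write(struct.pack("<i", len(values)))
            f.write(np.asarray(values, dtype=np.float32).tobytes())


# load_imatrix_data("toy.imatrix") would then return {"0_k": tensor of ones}
write_imatrix("toy.imatrix", {"blk.0.attn_k.weight": (1, np.ones(4096, dtype=np.float32))})
```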


def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
        # For quantization which needs importance matrix
        # module name preprocess
        # full name may be model.layers.31.self_attn.o_proj
        # TODO: just consider llama/mistral here
        # TODO: how to better align and generalize
        module_name = full_module_name.split('.')
        cur_qtype = qtype
        if len(module_name) == 5:
            layer = module_name[2]
            cur_module = module_name[-1][:-5]
            new_module_name = '_'.join([layer, cur_module])
        elif len(module_name) == 1:
            new_module_name = module_name[0]
            layer = None
            cur_module = None
        if imatrix_data is not None and new_module_name in imatrix_data:
            cur_imatrix = imatrix_data[new_module_name]
            # custom mixed quantization strategy
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                    or new_module_name == 'lm_head':
                cur_qtype = ggml_tensor_qtype['sym_int4']
        else:
            cur_imatrix = None
            # custom mixed quantization strategy
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                    or new_module_name == 'lm_head':
                cur_qtype = ggml_tensor_qtype['sym_int4']
    else:
        cur_imatrix = None
        cur_qtype = qtype

    return cur_qtype, cur_imatrix
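For example, under the llama naming assumed above, a fully qualified module name maps onto an imatrix key as follows (a sketch; the imatrix entry is synthetic):

```python
import torch

from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.utils import get_cur_qtype_and_imatrix

# key "31_o" is what load_imatrix_data produces for blk.31.attn_output.weight
imatrix_data = {"31_o": torch.ones(4096)}

cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(
    ggml_tensor_qtype["iq2_xxs"], "model.layers.31.self_attn.o_proj", imatrix_data)
# o_proj keeps iq2_xxs and receives its imatrix; v_proj, lm_head and down_proj in
# layers 0/1/10/11 would instead fall back to sym_int4 per the mixed strategy above.
```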
