
LLM: 2bit quantization support #10042

Merged 8 commits on Feb 6, 2024
24 changes: 24 additions & 0 deletions python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
@@ -965,6 +965,30 @@ def ggml_quantize_tensor(
_lib.ggml_quantize_tensor.restype = ctypes.c_size_t


def ggml_quantize_tensor_with_weights(
    src,  # type: ctypes.Array[ctypes.c_float] # type: ignore
    dst: ctypes.c_void_p,
    qtype: ctypes.c_int,
    nrow: ctypes.c_int,
    n_per_row: ctypes.c_int,
    hist,  # type: ctypes.Array[ctypes.c_int64] # type: ignore
    weights,  # type: ctypes.Array[ctypes.c_float] # type: ignore
) -> int:
    return _lib.ggml_quantize_tensor_with_weights(src, dst, qtype, nrow, n_per_row, hist, weights)


_lib.ggml_quantize_tensor_with_weights.argtypes = [
    ctypes.POINTER(ctypes.c_float),
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_int64),
    ctypes.POINTER(ctypes.c_float),
]
_lib.ggml_quantize_tensor_with_weights.restype = ctypes.c_size_t


def ggml_type_size(qtype: ctypes.c_int) -> int:
return _lib.ggml_type_size(qtype)

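For reference, a minimal sketch of how this new binding is expected to be driven from PyTorch, mirroring the call added in low_bit_linear.py below; the import alias, tensor shapes, and qtype lookup are assumptions for illustration only:

import ctypes
import torch
import bigdl.llm.ggml.model.llama.llama_cpp as ggml  # assumed import path
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

# toy float32 weight matrix plus one importance value per input column (assumed shapes)
weight = torch.randn(256, 256, dtype=torch.float)
imatrix = torch.rand(256, dtype=torch.float)

qtype = ggml_tensor_qtype["iq2_xxs"]
qk = ggml.ggml_qk_size(qtype)                               # elements per quantization block
dst = torch.empty((weight.numel() // qk) * ggml.ggml_type_size(qtype), dtype=torch.uint8)

src_p = ctypes.cast(weight.data_ptr(), ctypes.POINTER(ctypes.c_float))
imat_p = ctypes.cast(imatrix.data_ptr(), ctypes.POINTER(ctypes.c_float))
hist = (ctypes.c_int64 * 16)()                              # histogram buffer, as for ggml_quantize_tensor

ggml.ggml_quantize_tensor_with_weights(src_p, ctypes.c_void_p(dst.data_ptr()), qtype,
                                       weight.shape[0], weight.shape[1],  # nrow, n_per_row
                                       hist, imat_p)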
4 changes: 3 additions & 1 deletion python/llm/src/bigdl/llm/ggml/quantize.py
@@ -39,7 +39,9 @@
"mixed_fp8": 18, # Mixture of Formats Quantization 8 bits
"fp8_e5m2": 19, # fp8 in e5m2 format
"fp8": 19, # fp8 in e5m2 format
"bf16": 20}
"bf16": 20,
"iq2_xxs": 21,
"iq2_xs": 22}

_llama_quantize_type = {"q4_0": 2,
"q4_1": 3,
20 changes: 14 additions & 6 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -43,7 +43,7 @@
import transformers
import importlib.util
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from .utils import logger
from .utils import logger, get_cur_qtype_and_imatrix
from typing import Union
import numpy as np
import os
@@ -190,7 +190,8 @@ def convert_gptq(module, awq=False, llm_awq=False):

def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name=None, convert_shape_only=False,
cpu_embedding=False, prefix_name=''):
cpu_embedding=False, prefix_name='',
imatrix_data=None):
from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
FP16Linear, BF16Linear
from bigdl.llm.transformers.embedding import LLMEmbedding
@@ -248,15 +249,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
module.bias is not None,
mp_group=mp_group,
)

cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data)
device = module.weight.data.device
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
requires_grad=False,
quantized=False,
_shape=None,
convert_shape_only=convert_shape_only,
qtype=qtype).to(device)
qtype=cur_qtype,
imatrix=cur_imatrix,
in_features=in_features).to(device)
new_linear._parameters['weight'] = paramsLowBit
if module.bias is not None:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
@@ -328,7 +333,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
current_key_name,
convert_shape_only,
cpu_embedding,
prefix_name=prefix_name + '.' + name if prefix_name != '' else name
prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
imatrix_data=imatrix_data
)
has_been_replaced = _flag or has_been_replaced
return model, has_been_replaced
@@ -505,7 +511,8 @@ def _optimize_pre(model):
def ggml_convert_low_bit(model, qtype, optimize_model=True,
convert_shape_only=False, device="cpu",
modules_to_not_convert=None, cpu_embedding=False,
lightweight_bmm=False, torch_dtype="auto"):
lightweight_bmm=False, torch_dtype="auto",
imatrix_data=None):
logger.info(f"Converting the current model to "
f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
f"format......")
@@ -517,6 +524,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
model, has_been_replaced = _replace_with_low_bit_linear(
model, qtype, modules_to_not_convert,
None, convert_shape_only, cpu_embedding,
imatrix_data=imatrix_data,
)
if not has_been_replaced:
warnings.warn(
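As a rough end-to-end sketch of how the new imatrix_data argument is threaded through this converter; the model id and imatrix file name are placeholders, and the model is first loaded with plain Hugging Face transformers:

import torch
from transformers import AutoModelForCausalLM
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import ggml_convert_low_bit
from bigdl.llm.transformers.utils import load_imatrix_data

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",   # placeholder id
                                             torch_dtype=torch.float32)
imatrix_data = load_imatrix_data("llama-2-7b-imatrix.dat")                 # placeholder file
# every eligible nn.Linear becomes a LowBitLinear whose FP4Params carry the
# per-layer qtype and importance values selected by get_cur_qtype_and_imatrix
model = ggml_convert_low_bit(model, ggml_tensor_qtype["iq2_xxs"],
                             optimize_model=True,
                             imatrix_data=imatrix_data)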
33 changes: 25 additions & 8 deletions python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -70,6 +70,8 @@
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
IQ2_XS = ggml_tensor_qtype["iq2_xs"]


def get_block_size(qtype: str):
@@ -81,20 +83,20 @@ def get_qk_size(qtype: int):


def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
device=None, convert_shape_only=False):
device=None, convert_shape_only=False,
imatrix: torch.Tensor=None,
in_features: int=None):
QK = ggml.ggml_qk_size(qtype)
block_size_in_bytes = ggml.ggml_type_size(qtype)

invalidInputError(tensor.dtype == torch.float,
"Input tensor must be float32")
src = tensor.data.data_ptr()
src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
n = tensor.numel()
invalidInputError(n % QK == 0,
"Input tensor size must be multiple of 64")
n = tensor.numel() # all elements
k = tensor.shape[-1]
invalidInputError(k % QK == 0,
"Last dim of input tensor must be multiple of 64")
f"Last dim of input tensor must be multiple of {QK}")

dst_size = (n // QK) * block_size_in_bytes
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
@@ -103,7 +105,16 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
if not convert_shape_only and device != 'meta':
dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
hist = (ctypes.c_int64 * 16)()
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
if qtype not in [IQ2_XXS, IQ2_XS]:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
else:
# quantize with importance matrix
imatrix = imatrix.data.data_ptr()
imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
# pass nrow and n_per_row
ggml.ggml_quantize_tensor_with_weights(src, dst, qtype,
n // in_features, in_features,
hist, imatrix)
return dst_tensor


@@ -193,7 +204,9 @@ def __new__(cls,
quantized=False,
_shape=None,
convert_shape_only=False,
qtype=None):
qtype=None,
imatrix=None,
in_features=None):
if data is None:
data = torch.empty(0)

@@ -203,6 +216,8 @@ def __new__(cls,
self._shape = _shape
self.qtype = qtype
self.convert_shape_only = convert_shape_only
self.imatrix = imatrix
self.in_features = in_features
return self

def ggml_mse(self, w, ggml_qtype, device):
@@ -255,7 +270,9 @@ def quantize(self, device=None):
else:
w_quantized = ggml_convert_qtype(w, self.qtype,
device=device,
convert_shape_only=self.convert_shape_only)
convert_shape_only=self.convert_shape_only,
imatrix=self.imatrix,
in_features=self.in_features)
self.data = w_quantized
self.quantized = True
self._shape = w.shape
26 changes: 19 additions & 7 deletions python/llm/src/bigdl/llm/transformers/model.py
@@ -41,7 +41,7 @@
from transformers.configuration_utils import PretrainedConfig
from .utils import extract_local_archive_file, \
load_state_dict, \
get_local_shard_files
get_local_shard_files, load_imatrix_data
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.gguf.api import load_gguf_model
@@ -107,10 +107,10 @@ def from_pretrained(cls,
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
``'fp16'`` or ``'bf16'``, ``'sym_int4'`` means symmetric int 4,
``'asym_int4'`` means asymmetric int 4, ``'nf4'`` means 4-bit
NormalFloat, etc. Relevant low bit optimizations will be applied
to the model.
``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model.
:param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
Default to be ``True``.
:param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
@@ -121,6 +121,9 @@ def from_pretrained(cls,
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
:param lightweight_bmm: Whether to replace the torch.bmm ops, may need to set it
to ``True`` when running BigDL-LLM on GPU on Windows. Default to be ``False``.
:param imatrix: str value, representing the filename of an importance matrix
pretrained on specific datasets, for use with the improved quantization
methods recently added to llama.cpp.
:return: a model instance
"""
pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) \
Expand Down Expand Up @@ -243,6 +246,12 @@ def from_pretrained(cls,
else:
kwargs["pretraining_tp"] = 1
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
if q_k in ["iq2_xxs", "iq2_xs"]:
imatrix_file = kwargs.pop("imatrix", None)
invalidInputError(imatrix_file is not None,
"For iq2_xxs and iq2_xs quantization, imatrix is needed.")
imatrix_data = load_imatrix_data(imatrix_file)
kwargs['imatrix_data'] = imatrix_data
model = cls.load_convert(q_k, optimize_model, *args, **kwargs)

if speculative:
Expand Down Expand Up @@ -285,7 +294,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
invalidInputError(q_k in ggml_tensor_qtype,
f"Unknown load_in_low_bit value: {q_k}, expected:"
f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, mixed_fp4 or mixed_fp8.")
f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16, bf16, iq2_xxs, iq2_xs, "
f"mixed_fp4 or mixed_fp8.")
qtype = ggml_tensor_qtype[q_k]

# In case it needs a second try,
@@ -299,6 +309,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
cpu_embedding = True
lightweight_bmm = kwargs.pop("lightweight_bmm", False)
quant_config = kwargs.pop("quantization_config", None)
imatrix_data = kwargs.pop("imatrix_data", None)
_args = copy.deepcopy(args)
_kwargs = copy.deepcopy(kwargs)
awq_config = None
@@ -359,7 +370,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
model = ggml_convert_low_bit(model, qtype, optimize_model,
modules_to_not_convert=modules_to_not_convert,
cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
torch_dtype=kwargs.get("torch_dtype", 'auto'))
torch_dtype=kwargs.get("torch_dtype", 'auto'),
imatrix_data=imatrix_data)
model.config.update({"bigdl_transformers_low_bit": q_k})

# enable tie_word_embeddings for MPT
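For context, a minimal sketch of the resulting user-facing call, assuming a Llama-style checkpoint; the model id and imatrix file name are placeholders, and the imatrix file is the binary produced beforehand (e.g. by llama.cpp's imatrix tool) on a calibration dataset:

from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",           # placeholder model id
    load_in_low_bit="iq2_xxs",            # or "iq2_xs"
    imatrix="llama-2-7b-imatrix.dat",     # placeholder imatrix file
    optimize_model=True,
)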
79 changes: 79 additions & 0 deletions python/llm/src/bigdl/llm/transformers/utils.py
@@ -41,11 +41,14 @@
# SOFTWARE.
import os
from transformers.modeling_utils import _add_variant
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from ..utils.common import invalidInputError
from typing import Union
import torch
from torch import nn
import logging
import numpy as np


logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -179,3 +182,79 @@ def get_xpu_device_type(x):
return "pvc"
else:
return "others"


def load_imatrix_data(imatrix_file):
    # this function is adapted from https://github.com/ggerganov/llama.cpp/blob/
    # c82d18e863fcde91b4b1109b1d0c73ea4470c405/examples/quantize/quantize.cpp#L102
    imatrix = open(imatrix_file, 'rb')
    n_entries = imatrix.read(4)
    n_entries = int.from_bytes(n_entries, 'little')
    invalidInputError(n_entries >= 1,
                      f"failed reading name for entry from {imatrix_file}")
    imatrix_data = {}
    for i in range(n_entries):
        cur_len = imatrix.read(4)
        cur_len = int.from_bytes(cur_len, 'little')
        cur_name = str(imatrix.read(cur_len), encoding='utf-8')
        # the original cur_name looks like blk.14.attn_output.weight for llama
        # TODO: how to better align and generalize this
        name_list = cur_name.split('.')
        layer = name_list[1]
        module_name = name_list[2]
        if 'ffn' in module_name:
            module_name = module_name[4:]  # from ffn_gate to gate
        elif 'attn' in module_name:
            module_name = module_name[5]  # from attn_k to k, attn_output to o
        module_name = layer + '_' + module_name
        ncall = imatrix.read(4)
        ncall = int.from_bytes(ncall, 'little')
        nval = imatrix.read(4)
        nval = int.from_bytes(nval, 'little')
        invalidInputError(nval >= 1,
                          f"failed reading number of values for entry {i}")
        byte_data = imatrix.read(4 * nval)
        idata = np.frombuffer(byte_data, dtype=np.float32)

        if ncall > 0:
            idata = idata / ncall
        imatrix_data[module_name] = torch.from_numpy(idata).float()

    print(f"loaded {len(imatrix_data)} importance matrix entries from {imatrix_file}.")
    return imatrix_data
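To make the binary layout concrete, a small self-contained sketch that writes a one-entry file in the format parsed above and reads it back; the entry name and values are made up for illustration:

import struct
from bigdl.llm.transformers.utils import load_imatrix_data

# layout per the parser above: [n_entries] then, per entry,
# [len(name)][name][ncall][nval][nval * float32]
name = b"blk.0.attn_v.weight"
values = [1.0, 2.0, 3.0, 4.0]
with open("tiny-imatrix.dat", "wb") as f:
    f.write(struct.pack("<i", 1))                       # one entry
    f.write(struct.pack("<i", len(name)) + name)
    f.write(struct.pack("<i", 2))                       # ncall
    f.write(struct.pack("<i", len(values)))             # nval
    f.write(struct.pack(f"<{len(values)}f", *values))

data = load_imatrix_data("tiny-imatrix.dat")
print(data)  # {'0_v': tensor([0.5000, 1.0000, 1.5000, 2.0000])} -- values divided by ncall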


def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data):
    if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
        # For quantization which needs an importance matrix,
        # preprocess the module name first.
        # The full name may look like model.layers.31.self_attn.o_proj
        # TODO: only llama/mistral are considered here
        # TODO: how to better align and generalize this
        module_name = full_module_name.split('.')
        cur_qtype = qtype
        if len(module_name) == 5:
            layer = module_name[2]
            cur_module = module_name[-1][:-5]
            new_module_name = '_'.join([layer, cur_module])
        elif len(module_name) == 1:
            new_module_name = module_name[0]
            layer = None
            cur_module = None
        if imatrix_data is not None and new_module_name in imatrix_data:
            cur_imatrix = imatrix_data[new_module_name]
            # custom mixed quantization strategy
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                    or new_module_name == 'lm_head':
                cur_qtype = ggml_tensor_qtype['sym_int4']
        else:
            cur_imatrix = None
            # custom mixed quantization strategy
            if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]) \
                    or new_module_name == 'lm_head':
                cur_qtype = ggml_tensor_qtype['sym_int4']
    else:
        cur_imatrix = None
        cur_qtype = qtype

    return cur_qtype, cur_imatrix
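A short sketch of how the two helpers in this file fit together; the imatrix file name is a placeholder, and the comments restate the custom mixed strategy encoded above:

from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.utils import load_imatrix_data, get_cur_qtype_and_imatrix

imatrix_data = load_imatrix_data("llama-2-7b-imatrix.dat")   # placeholder file
qtype = ggml_tensor_qtype["iq2_xxs"]

# 'model.layers.31.self_attn.o_proj' maps to key '31_o' and keeps iq2_xxs
print(get_cur_qtype_and_imatrix(qtype, "model.layers.31.self_attn.o_proj", imatrix_data))

# v_proj, lm_head, and down_proj in layers 0/1/10/11 fall back to sym_int4
# under the custom mixed strategy above
print(get_cur_qtype_and_imatrix(qtype, "model.layers.0.mlp.down_proj", imatrix_data))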