LLM: support quantized kv cache for Mistral in transformers >=4.36.0 #10326

Merged
merged 4 commits on Mar 5, 2024
5 changes: 5 additions & 0 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -1092,10 +1092,15 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.mistral import mistral_attention_forward_4_36
+        from bigdl.llm.transformers.models.mistral import mistral_model_forward_4_36
         convert_forward(model,
                         module.MistralAttention,
                         mistral_attention_forward_4_36
                         )
+        convert_forward(model,
+                        module.MistralModel,
+                        mistral_model_forward_4_36
+                        )
         convert_forward(model,
                         module.MistralRMSNorm,
                         llama_rms_norm_forward)
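To make the registration above concrete: convert_forward is the BigDL patching helper that swaps a module class's forward implementation across the whole model. A minimal sketch of the idea, assuming it simply rebinds forward on every instance of the target class (the real helper may do more bookkeeping):

    # Sketch only: assumes convert_forward rebinds `forward` on each matching
    # sub-module; the actual bigdl.llm helper may differ in details.
    import types

    def convert_forward_sketch(model, target_class, new_forward):
        for module in model.modules():
            if isinstance(module, target_class):
                # bind new_forward as an instance method of this module
                module.forward = types.MethodType(new_forward, module)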
20 changes: 13 additions & 7 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -53,6 +53,10 @@
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
+try:
+    from transformers.cache_utils import Cache
+except ImportError:
+    Cache = Tuple[torch.Tensor]


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
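The try/except added above is a version guard: transformers.cache_utils.Cache only exists from transformers 4.36 on, so older installs fall back to a tuple alias and the annotations below still resolve. The same pattern in isolation:

    # Self-contained illustration of the guard; a typing alias is an ordinary
    # object, so Optional[Cache] is a valid annotation either way.
    from typing import Optional, Tuple
    import torch

    try:
        from transformers.cache_utils import Cache
    except ImportError:  # transformers < 4.36
        Cache = Tuple[torch.Tensor]

    def forward_stub(past_key_value: Optional[Cache] = None) -> Optional[Cache]:
        return past_key_value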
@@ -934,11 +938,11 @@ def llama_attention_forward_4_36(
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    past_key_value: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
         forward_function = llama_attention_forward_4_36_quantized
     else:
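This wrapper routes each call to either the quantized or the original implementation. The diff does not show how use_quantize_kv_cache decides; a hedged sketch of the likely shape, assuming an environment flag such as BIGDL_QUANTIZE_KV_CACHE gates the feature (the real check presumably also inspects the weight qtype and input shape):

    # Assumption: an env flag plus a device check; names and logic here are
    # illustrative, not the actual bigdl.llm implementation.
    import os
    import torch

    def use_quantize_kv_cache_sketch(linear: torch.nn.Module, x: torch.Tensor) -> bool:
        return os.environ.get("BIGDL_QUANTIZE_KV_CACHE", "0") == "1" \
            and x.device.type == "xpu"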
@@ -960,11 +964,11 @@ def llama_attention_forward_4_36_quantized(
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    past_key_value: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -999,8 +1003,10 @@ def llama_attention_forward_4_36_quantized(
                                                                          position_ids,
                                                                          tmp_cache_k, tmp_cache_v,
                                                                          self.q_proj.weight.qtype,
+                                                                         self.v_proj.weight.qtype,
                                                                          0,
-                                                                         self.head_dim)
+                                                                         self.head_dim,
+                                                                         self.rotary_emb.base,)
     else:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
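The fused kernel now also receives self.rotary_emb.base, presumably so it can derive the RoPE cos/sin tables itself instead of taking them precomputed. For reference, the standard frequencies that base implies, in plain PyTorch:

    # Reference math only: RoPE inverse frequencies and per-position angles.
    import torch

    def rope_angles(head_dim: int, positions: torch.Tensor, base: float = 10000.0):
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        return torch.outer(positions.float(), inv_freq)  # (seq_len, head_dim // 2)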
@@ -1140,11 +1146,11 @@ def llama_attention_forward_4_36_original(
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    past_key_value: Optional[Cache] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
253 changes: 248 additions & 5 deletions python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -36,11 +36,13 @@
 # limitations under the License.
 """ PyTorch Mistral model."""
 import math
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import nn
 import torch.nn.functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.mistral.modeling_mistral import MistralModel
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
@@ -51,7 +53,10 @@
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-
+try:
+    from transformers.cache_utils import Cache
+except ImportError:
+    Cache = Tuple[torch.Tensor]
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
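KV_CACHE_ALLOC_BLOCK_LENGTH = 256 indicates the kv cache is grown in fixed 256-token blocks rather than per token, which amortizes reallocation during decoding. A sketch of the rounding this implies (illustrative; the actual allocation lives in init_kv_cache/extend_kv_cache):

    # Illustrative only: round the required length up to whole blocks.
    def alloc_length(seq_len: int, block: int = 256) -> int:
        return ((seq_len + block - 1) // block) * block

    assert alloc_length(1) == 256 and alloc_length(257) == 512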


@@ -121,6 +126,37 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
     return attn_output, attn_weights


+def mistral_model_forward_4_36(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    from bigdl.llm.transformers.kv import DynamicFp8Cache
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    return MistralModel.forward(
+        self=self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+
 def mistral_attention_forward(
     self,
     hidden_states: torch.Tensor,
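mistral_model_forward_4_36 only swaps the cache object and then defers to the stock MistralModel.forward; the quantization itself lives in DynamicFp8Cache, whose source is not part of this diff. A minimal, self-contained sketch of a cache exposing the same transformers>=4.36 update interface, with float16 standing in for fp8:

    # Sketch, not BigDL's DynamicFp8Cache: store K/V in a compressed dtype and
    # return usable tensors from update(); float16 stands in for fp8 here.
    import torch
    from transformers.cache_utils import DynamicCache

    class SketchQuantizedCache(DynamicCache):
        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            key_states = key_states.to(torch.float16)
            value_states = value_states.to(torch.float16)
            return super().update(key_states, value_states, layer_idx, cache_kwargs)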
@@ -480,11 +516,218 @@ def mistral_attention_forward_4_36(
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor]=None,
     position_ids: Optional[torch.LongTensor]=None,
-    past_key_value: Optional[Tuple[torch.Tensor]]=None,
+    past_key_value: Optional[Cache]=None,
     output_attentions: bool=False,
     use_cache: bool=False,
-    padding_mask: Optional[torch.Tensor]=None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    if use_quantize_kv_cache(self.q_proj, hidden_states):
+        forward_function = mistral_attention_forward_4_36_quantized
+    else:
+        forward_function = mistral_attention_forward_4_36_original
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs
+    )
+
+
+def mistral_attention_forward_4_36_quantized(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor]=None,
+    position_ids: Optional[torch.LongTensor]=None,
+    past_key_value: Optional[Cache]=None,
+    output_attentions: bool=False,
+    use_cache: bool=False,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len)
+
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        tmp_cache_k, tmp_cache_v = init_kv_cache(
+            bsz,
+            self.num_key_value_heads,
+            self.head_dim,
+            0,
+            1,
+            dtype=hidden_states.dtype,
+            device=device
+        )
+        import linear_q4_0
+        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
+                                                                         self.q_proj.weight,
+                                                                         self.k_proj.weight,
+                                                                         self.v_proj.weight,
+                                                                         position_ids,
+                                                                         tmp_cache_k, tmp_cache_v,
+                                                                         self.q_proj.weight.qtype,
+                                                                         self.v_proj.weight.qtype,
+                                                                         0,
+                                                                         self.head_dim)
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            invalidInputError(
+                False,
+                f"The cache structure has changed since version v4.36. "
+                f"If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, "
+                "please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "mistral")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids, "mistral")
+
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+    kv_seq_len = key_states.shape[-2]
+    if len(past_key_value.key_cache) <= self.layer_idx:
+        # prefill: no quantized cache exists for this layer yet
+        attn_weights = torch.matmul(query_states.to(key_states.dtype),
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size "
+                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+    else:
+        # decode: append to the fp8 cache, then attend over the restored k/v
+        cache_kwargs = None  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, cache_kwargs)
+        kv_seq_len = key_states.shape[-2]
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
+                                                            query_states.dtype)
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+        else:
+            import linear_q4_0
+            attn_weights = linear_q4_0.query_key_fp8_matmul(query_states, key_states)
+
+        attn_weights = attn_weights / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size "
+                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                             dtype=torch.float32).to(query_states.dtype)
+
+        if query_states.size(2) != 1 or query_states.device.type != 'xpu':
+            attn_output = torch.matmul(attn_weights, value_states)
+        else:
+            import linear_q4_0
+            attn_output = linear_q4_0.attn_value_fp8_matmul(attn_weights,
+                                                            value_states.transpose(-1, -2))
+
+    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+    if attn_output.size() != attn_output_size:
+        invalidInputError(False,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def mistral_attention_forward_4_36_original(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor]=None,
+    position_ids: Optional[torch.LongTensor]=None,
+    past_key_value: Optional[Cache]=None,
+    output_attentions: bool=False,
+    use_cache: bool=False,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
     bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
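Taken together, an end-to-end sketch of how the new path would presumably be exercised. The BIGDL_QUANTIZE_KV_CACHE flag and the xpu device are assumptions, and the checkpoint id is only an example:

    # Hypothetical usage; the flag name is an assumption, not confirmed by this PR.
    import os
    os.environ["BIGDL_QUANTIZE_KV_CACHE"] = "1"

    import torch
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",  # example checkpoint
        load_in_4bit=True,            # low-bit weights; kv-cache quantization is separate
        trust_remote_code=True,
    ).to("xpu")
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    inputs = tokenizer("Once upon a time", return_tensors="pt").to("xpu")
    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))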