From c72250a52ddd3aed758734a3a9ec92f3dc317e37 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Wed, 6 Mar 2024 20:35:31 +0200 Subject: [PATCH 1/8] enable Falcon FP8 inference --- .../maxabs_measure_falcon.json | 10 + examples/text-generation/utils.py | 2 +- .../habana/transformers/generation/utils.py | 2 +- optimum/habana/transformers/modeling_utils.py | 10 +- .../habana/transformers/models/__init__.py | 5 +- .../transformers/models/falcon/__init__.py | 5 +- .../models/falcon/modeling_falcon.py | 941 ++++++++++++++---- 7 files changed, 762 insertions(+), 213 deletions(-) create mode 100644 examples/text-generation/quantization_config/maxabs_measure_falcon.json diff --git a/examples/text-generation/quantization_config/maxabs_measure_falcon.json b/examples/text-generation/quantization_config/maxabs_measure_falcon.json new file mode 100644 index 0000000000..32e9e2209e --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_measure_falcon.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx", + "measure_exclude": "NONE" +} diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 96253f7726..f7a3a6ab65 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -238,7 +238,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = deepspeed.init_inference(model, **ds_inference_kwargs) model = model.module - if model.config.model_type == "llama": + if model.config.model_type == "llama" or "falcon": patch_scoped_linear_all_reduce(model) if args.quant_config: diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 755dec4516..b5ec87175c 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -725,7 +725,7 @@ def generate( ) model_kwargs["kv_cache_len"] = calculated_max_length - if self.config.model_type in ["llama"]: + if self.config.model_type in ["llama", "falcon"]: if self.config.max_position_embeddings < calculated_max_length: unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index bab0f650f3..c471577969 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -21,7 +21,10 @@ GaudiBloomMLP, GaudiCodeGenAttention, GaudiCodeGenForCausalLM, + GaudiFalconAttention, + GaudiFalconDecoderLayer, GaudiFalconForCausalLM, + GaudiFalconMLP, GaudiFalconModel, GaudiGPT2Attention, GaudiGPT2LMHeadModel, @@ -63,9 +66,7 @@ gaudi_conv1d_forward, gaudi_esm_for_protein_folding_forward, gaudi_esmfolding_trunk_forward, - gaudi_falcon_attention_forward, gaudi_falcon_attention_split_heads, - gaudi_falcon_decoder_layer_forward, gaudi_get_extended_attention_mask, gaudi_gpt2_block_forward, gaudi_gpt2_forward, @@ -258,10 +259,11 @@ def adapt_transformers_to_gaudi(): transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward # Optimization for falcon generation on Gaudi + transformers.models.falcon.modeling_falcon.FalconAttention = GaudiFalconAttention transformers.models.falcon.modeling_falcon.FalconForCausalLM = GaudiFalconForCausalLM + 
transformers.models.falcon.modeling_falcon.FalconMLP = GaudiFalconMLP transformers.models.falcon.modeling_falcon.FalconModel = GaudiFalconModel - transformers.models.falcon.modeling_falcon.FalconDecoderLayer.forward = gaudi_falcon_decoder_layer_forward - transformers.models.falcon.modeling_falcon.FalconAttention.forward = gaudi_falcon_attention_forward + transformers.models.falcon.modeling_falcon.FalconDecoderLayer = GaudiFalconDecoderLayer transformers.models.falcon.modeling_falcon.FalconAttention._split_heads = gaudi_falcon_attention_split_heads # Optimization for t5 on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 4232534590..a6c14c39ad 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -32,11 +32,12 @@ gaudi_rot_vec_mul, ) from .falcon import ( + GaudiFalconAttention, + GaudiFalconDecoderLayer, GaudiFalconForCausalLM, + GaudiFalconMLP, GaudiFalconModel, - gaudi_falcon_attention_forward, gaudi_falcon_attention_split_heads, - gaudi_falcon_decoder_layer_forward, ) from .gpt2 import GaudiGPT2Attention, GaudiGPT2LMHeadModel, gaudi_gpt2_block_forward, gaudi_gpt2_forward from .gpt_bigcode import ( diff --git a/optimum/habana/transformers/models/falcon/__init__.py b/optimum/habana/transformers/models/falcon/__init__.py index 44ac5451f6..00c73ad110 100644 --- a/optimum/habana/transformers/models/falcon/__init__.py +++ b/optimum/habana/transformers/models/falcon/__init__.py @@ -1,7 +1,8 @@ from .modeling_falcon import ( + GaudiFalconAttention, + GaudiFalconDecoderLayer, GaudiFalconForCausalLM, + GaudiFalconMLP, GaudiFalconModel, - gaudi_falcon_attention_forward, gaudi_falcon_attention_split_heads, - gaudi_falcon_decoder_layer_forward, ) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 9c853dfb2a..3322711e40 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -2,6 +2,7 @@ import math import warnings from typing import Optional, Tuple, Union +import os import torch @@ -27,6 +28,7 @@ import habana_frameworks.torch.core as htcore +from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa @@ -34,13 +36,19 @@ BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) +from transformers.models.falcon.configuration_falcon import FalconConfig from transformers.models.falcon.modeling_falcon import ( + FalconAttention, + FalconDecoderLayer, FalconForCausalLM, + FalconMLP, + FalconLinear, FalconModel, + FalconRotaryEmbedding, apply_rotary_pos_emb, build_alibi_tensor, - dropout_add, ) +from ..modeling_all_models import ScopedLinearAllReduce from transformers.utils import logging from ...modeling_attn_mask_utils import ( @@ -52,6 +60,25 @@ logger = logging.get_logger(__name__) +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + residual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out.add_(residual) + return out + + def apply_customized_rope(q, 
k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: # TODO: remove `.clone()` when SynapseAI v1.15 is released @@ -111,257 +138,721 @@ def gaudi_falcon_attention_split_heads( return query, key, value -def gaudi_falcon_attention_forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -): - """ - Copied from FalconAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - replace F.scaled_dot_product_attention with Habana torch's version - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, invAttnHead=None): + return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead) + + +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +# ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention +class ScaledDotProductAttention(nn.Module): + def __init__(self, config: FalconConfig): + super().__init__() + self.head_dim = config.hidden_size // config.num_attention_heads + self.bmm1 = Matmul() + self.bmm2 = Matmul() + self.softmax = Softmax() + + def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(self.head_dim) + invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") + + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + + attn_weight = self.bmm1(query, key.transpose(-2, -1)) + + attn_weight += attn_mask + attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return self.bmm2(attn_weight, value) + + +def update(prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + cur = cur.to(dtype=prev.dtype) + + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + + if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. 
prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + prev_cast = prev.to(orig_cur.dtype) + return prev_cast + else: + return torch.cat((prev, cur), dim=dim) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, query_length, _, _ = query_layer.shape +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 - query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + def update(self, prev, cur, dim, idx, inp_seq_len): + return update(prev, cur, dim, idx, inp_seq_len) - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - if token_idx is not None: - # When token_idx is used, - # past_kv_length = 0 - # static seq len = (input token len + max output token len) - kv_seq_len = layer_past[0].shape[-2] + +class GaudiFalconAttention(FalconAttention): + def __init__(self, config: FalconConfig): + super().__init__(config) + + if config.new_decoder_architecture: + qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim else: - kv_seq_len += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None: - past_key, past_value = layer_past - if token_idx is not None: - past_key.index_copy_(-2, token_idx - 1, key_layer) - past_value.index_copy_(-2, token_idx - 1, value_layer) - key_layer = past_key - value_layer = past_value + qkv_out_dim = 3 * self.hidden_size + + if os.getenv("QUANT_CONFIG", ""): + self.sdpa = ScaledDotProductAttention(config) + + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.max_position_embeddings = config.max_position_embeddings + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + if self.config.new_decoder_architecture: + cache_shape = (batch_size, self.num_heads, max_seq_len, self.head_dim) else: - # concatenate along seq_length dimension: - # - key: [batch_size, self.num_heads, kv_length, head_dim] - # - value: [batch_size, self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=-2) - value_layer = torch.cat((past_value, value_layer), dim=-2) - - kv_length = key_layer.shape[-2] - if use_cache: - present = (key_layer, value_layer) - else: - present = None + cache_shape = (batch_size, 1, max_seq_len, self.head_dim) + device = 
self.query_key_value.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + self.rotary_emb._set_cos_sin_cache( + seq_len, self.query_key_value.weight.device, self.query_key_value.weight.dtype + ) - if alibi is None: - if output_attentions: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ): + """ + Copied from FalconAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - replace F.scaled_dot_product_attention with Habana torch's version + - add new args reuse_cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). 
- attn_output = attention_scores @ value_layer + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + + kv_seq_len = key_layer.shape[-2] + if layer_past is not None: + if token_idx is not None: + if reuse_cache: + kv_seq_len = layer_past[0][-2] # layer_past conveys only shapes without kv tensors + else: + kv_seq_len = layer_past[0].shape[-2] + else: + kv_length += layer_past[0].shape[-2] + if alibi is None: + cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) + + if layer_past is not None or reuse_cache: + if reuse_cache: + key_layer = self.k_cache(key_layer, -2, token_idx) + value_layer = self.v_cache(value_layer, -2, token_idx) + else: + key_layer = update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + + if cache_idx is not None and query_length == 1: + key_layer = key_layer[:, :, :cache_idx, :] + value_layer = value_layer[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_layer.shape[-2] + + kv_length = key_layer.shape[-2] + if use_cache: + if reuse_cache: + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + present = (key_layer, value_layer) else: - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( + present = None + + if alibi is None: + if output_attentions: + attention_scores = query_layer @ key_layer.transpose(-1, -2) + attention_scores /= math.sqrt(self.head_dim) + + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + attn_output = attention_scores @ value_layer + else: + if FusedSDPA: + if os.getenv("QUANT_CONFIG", ""): + attn_output = self.sdpa( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask, 0.0, # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
- self.is_causal and attention_mask is None and query_length > 1, + is_causal=self.is_causal and attention_mask is None and query_length > 1, ) + # Performance improvement for HPU + if self.training is True and htcore: + htcore.mark_step() + attention_scores = None + + attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, -1) + + attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_scores else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - # Performance improvement for HPU - if self.training is True and htcore: - htcore.mark_step() - attention_scores = None + return attn_output, present + + else: + if self._use_sdpa and not output_attentions and head_mask is None: + if FusedSDPA: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + self.attention_dropout.p if self.training else 0.0, + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, -1) + attn_output = self.dense(attn_output) + else: + matmul_result = query_layer @ key_layer.transpose(-1, -2) - attn_output = self.dense(attn_output) + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - if output_attentions: - return attn_output, present, attention_scores + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) + + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * 
head_mask + + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) + + # change view [batch_size, q_length, num_heads * head_dim] + attn_output = self._merge_heads(attn_output) + + attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_probs + else: + return attn_output, present + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ): + """ + Copied from FalconAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - replace F.scaled_dot_product_attention with Habana torch's version + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) + + kv_seq_len = key_layer.shape[-2] + if layer_past is not None: + if token_idx is not None: + if reuse_cache: + kv_seq_len = layer_past[0][-2] + else: + kv_seq_len = layer_past[0].shape[-2] + else: + kv_seq_len += layer_past[0].shape[-2] + + if alibi is None: + cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) + + if layer_past is not None or reuse_cache: + if reuse_cache: + key_layer = self.k_cache(key_layer, -2, token_idx) + value_layer = self.v_cache(value_layer, -2, token_idx) + else: + key_layer = update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + + if cache_idx is not None and query_length == 1: + key_layer = key_layer[:, :, :cache_idx, :] + value_layer = value_layer[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_layer.shape[-2] + + kv_length = key_layer.shape[-2] + if use_cache: + if reuse_cache: + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + present = (key_layer, value_layer) else: - return attn_output, present + present = None - else: - if self._use_sdpa and not output_attentions and head_mask is None: - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = 
FusedSDPA.apply( + if alibi is None: + if output_attentions: + attention_scores = query_layer @ key_layer.transpose(-1, -2) + attention_scores /= math.sqrt(self.head_dim) + + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) + # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi). + attn_output = attention_scores @ value_layer + else: + if FusedSDPA: + if os.getenv("QUANT_CONFIG", ""): + attn_output = self.sdpa( + query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False + ) + + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + # Workaround util scaled_dot_product_attention support broadcast. + if self.training is True and query_layer.shape != key_layer.shape: + key_layer = torch.broadcast_to(key_layer, query_layer.shape) + value_layer = torch.broadcast_to(value_layer, query_layer.shape) + attn_output = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask, - self.attention_dropout.p if self.training else 0.0, - self.is_causal and attention_mask is None and query_length > 1, + 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. + is_causal=self.is_causal and attention_mask is None and query_length > 1, ) - else: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + # Performance improvement for HPU + if self.training is True and htcore: + htcore.mark_step() + attention_scores = None + + attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, -1) attn_output = self.dense(attn_output) + + if output_attentions: + return attn_output, present, attention_scores + else: + return attn_output, present + else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) + if self._use_sdpa and not output_attentions and head_mask is None: + if FusedSDPA: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_layer, + key_layer, + value_layer, + attention_mask, + self.attention_dropout.p if self.training else 0.0, + self.is_causal and attention_mask is None and query_length > 1, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.attention_dropout.p if self.training else 0.0, + is_causal=self.is_causal and attention_mask is None and query_length > 1, + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - # change view to [batch_size, num_heads, q_length, kv_length] - 
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + attn_output = self.dense(attn_output) + else: + matmul_result = query_layer @ key_layer.transpose(-1, -2) - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) - if head_mask is not None: - attention_probs = attention_probs * head_mask + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + if head_mask is not None: + attention_probs = attention_probs * head_mask - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) + # matmul: [batch_size * num_heads, q_length, head_dim] + attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - attn_output = self.dense(attn_output) + # change view [batch_size, q_length, num_heads * head_dim] + attn_output = self._merge_heads(attn_output) - if output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - - -def gaudi_falcon_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -): - """ - Copied from FalconDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The 
only differences are: - - add new args token_idx and position_ids - - add token_idx and position_ids into attention inputs - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + attn_output = self.dense(attn_output) - residual = hidden_states + if output_attentions: + return attn_output, present, attention_probs + else: + return attn_output, present - if self.config.new_decoder_architecture: - attention_layernorm_out = self.ln_attn(hidden_states) - mlp_layernorm_out = self.ln_mlp(hidden_states) - else: - attention_layernorm_out = self.input_layernorm(hidden_states) - - # Self attention. - attn_outputs = self.self_attention( - attention_layernorm_out, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, + def attention_all_reduce(self, attn_output): + if hasattr(self.dense, "all_reduce"): + self.dense.all_reduce(attn_output) + + def post_attn_forward(self, attn_output): + if hasattr(self.dense, "all_reduce"): + self.dense.post_all_reduce(attn_output) + return attn_output + + +class GaudiFalconMLP(FalconMLP): + def pre_mlp_forward(self, x): + x = self.act(self.dense_h_to_4h(x)) + x = self.dense_4h_to_h(x) + return x + + def mlp_all_reduce(self, x): + if hasattr(self.dense_4h_to_h, "all_reduce"): + self.dense_4h_to_h.all_reduce(x) + + def post_mlp_forward(self, x): + if hasattr(self.dense_4h_to_h, "all_reduce"): + self.dense_4h_to_h.post_all_reduce(x) + return x + + +class GaudiFalconDecoderLayer(FalconDecoderLayer): + def __init__(self, config: FalconConfig): + super().__init__(config) + self.self_attention = GaudiFalconAttention(config) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attention.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def update_sincos_cache(self, seq_len): + self.self_attention.update_sincos_cache(seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, **kwargs, - ) + ): + """ + Copied from FalconDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py + The only differences are: + - add new args token_idx and position_ids + - add token_idx and position_ids into attention inputs + - add new args reuse_cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + if not self.config.new_decoder_architecture: + residual = hidden_states + + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. 
+ attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + **kwargs, + ) + + attention_output = attn_outputs[0] - attention_output = attn_outputs[0] + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) - if not self.config.new_decoder_architecture: - if self.config.parallel_attn: - mlp_layernorm_out = attention_layernorm_out + outputs = attn_outputs[1:] else: - residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training) - mlp_layernorm_out = self.post_attention_layernorm(residual) + residual = hidden_states + hidden_states, present, attn_scores, mlp_layernorm_out = ( + self.pre_attn( # layernorm+attention before AllReduce + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + **kwargs, + ) + ) - outputs = attn_outputs[1:] + self.self_attention.attention_all_reduce(hidden_states) + hidden_states = self.self_attention.post_attn_forward( + hidden_states + ) - # MLP. - mlp_output = self.mlp(mlp_layernorm_out) + attention_output = hidden_states - if self.config.new_decoder_architecture or self.config.parallel_attn: - mlp_output += attention_output + outputs = (present, attn_scores) - output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + # MLP + if not self.config.new_decoder_architecture: + hidden_states = self.mlp(mlp_layernorm_out) + else: + hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.mlp.post_mlp_forward(hidden_states) - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] + if self.config.new_decoder_architecture or self.config.parallel_attn: + hidden_states += attention_output - return outputs # hidden_states, present, attentions + output = dropout_add(hidden_states, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + def pre_attn( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, + ): + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + mlp_layernorm_out = None + + # Self attention. 
+ attn_scores = None + if output_attentions: + attn_outputs, present, attn_scores = self.self_attention.pre_attn_forward( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + ) + else: + attn_outputs, present = self.self_attention.pre_attn_forward( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + ) + + return attn_outputs, present, attn_scores, mlp_layernorm_out class GaudiFalconModel(FalconModel): @@ -375,6 +866,14 @@ class GaudiFalconModel(FalconModel): - use old version of _make_causal_mask to workaround toch.triu that is not supported in Synapse """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.h: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def update_sincos_cache(self, seq_len): + for layer in self.h: + layer.update_sincos_cache(seq_len) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -388,6 +887,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -425,8 +926,11 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation past_key_values_length = 0 - if past_key_values[0] is not None and token_idx is None: - past_key_values_length = past_key_values[0][0].shape[-2] + if past_key_values[0] is not None and token_idx is None: ### non static input + if reuse_cache: + past_key_values_length = past_key_values[0][0][-2] + else: + past_key_values_length = past_key_values[0][0].shape[-2] if self.use_alibi: mask = ( @@ -489,6 +993,7 @@ def forward( attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + else: # 4d mask is passed through the layers attention_mask = _gaudi_prepare_4d_causal_attention_mask( @@ -501,6 +1006,7 @@ def forward( # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + htcore.mark_step() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -529,6 +1035,8 @@ def forward( output_attentions=output_attentions, alibi=alibi, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -563,8 +1071,16 @@ class GaudiFalconForCausalLM(FalconForCausalLM): - add token_idx and position_ids into model inputs - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx + - add new args reuse_cache """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.transformer.allocate_kv_cache(batch_size, 
max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + + def update_sincos_cache(self, seq_len): + self.transformer.update_sincos_cache(seq_len) + def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, @@ -574,6 +1090,7 @@ def prepare_inputs_for_generation( token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> dict: + reuse_cache = kwargs.get("reuse_cache") if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -588,6 +1105,10 @@ def prepare_inputs_for_generation( remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. if ( @@ -612,6 +1133,8 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "cache_idx": kwargs.get("cache_idx"), } def forward( @@ -628,6 +1151,9 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + trim_logits: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -649,9 +1175,18 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = transformer_outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1:, :] + lm_logits = self.lm_head(hidden_states) loss = None From 63fd6b2b2c77ef6cd8f86062dd760814bfa4d1f3 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Fri, 8 Mar 2024 01:53:15 +0200 Subject: [PATCH 2/8] added example command in readme, code cleanup --- examples/text-generation/README.md | 35 ++- .../maxabs_measure_falcon.json | 10 - .../models/falcon/modeling_falcon.py | 273 ++---------------- 3 files changed, 58 insertions(+), 260 deletions(-) delete mode 100644 examples/text-generation/quantization_config/maxabs_measure_falcon.json diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 138a834599..9ca3e396ab 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -240,7 +240,7 @@ While `--bucket_size` works for any model without model file changes, an even mo ### Running with FP8 -Llama2-70b, Llama2-7b and Mixtral-8x7B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. +Llama2-70b, Llama2-7b, Mixtral-8x7B, Falcon-7B, Falcon-40B, and Falcon-180B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. 
More information on enabling fp8 in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html @@ -320,6 +320,39 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati --bf16 \ --fp8 ``` + +Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: +```bash +QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \ +--use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path tiiuae/falcon-180B \ +--use_hpu_graphs \ +--use_kv_cache \ +--limit_hpu_graphs \ +--max_input_tokens 128 \ +--max_new_tokens 128 \ +--batch_size 1 \ +--bf16 \ +--reuse_cache \ +--trim_logits +``` + +Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards: +```bash +QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \ +--use_deepspeed --world_size 8 run_generation.py \ +--model_name_or_path tiiuae/falcon-180B \ +--use_hpu_graphs \ +--use_kv_cache \ +--limit_hpu_graphs \ +--max_input_tokens 128 \ +--max_new_tokens 2048 \ +--batch_size 110 \ +--bf16 \ +--reuse_cache \ +--trim_logits \ +--fp8 +``` `--fp8` is required to enable quantization in fp8. ### Using Habana Flash Attention diff --git a/examples/text-generation/quantization_config/maxabs_measure_falcon.json b/examples/text-generation/quantization_config/maxabs_measure_falcon.json deleted file mode 100644 index 32e9e2209e..0000000000 --- a/examples/text-generation/quantization_config/maxabs_measure_falcon.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "method": "HOOKS", - "mode": "MEASURE", - "observer": "maxabs", - "whitelist": {"types": [], "names": []}, - "blacklist": {"types": [], "names": []}, - "dump_stats_path": "./hqt_output/measure", - "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx", - "measure_exclude": "NONE" -} diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 3322711e40..1db2a61022 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -62,17 +62,8 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: """ - Dropout add function - - Args: - x (`torch.tensor`, *required*): - input tensor - residual (`torch.tensor`, *required*): - residual tensor - prob (`float`, *required*): - dropout probability - training (`bool`, *required*): - training mode + Copied from transformers.models.falcon.modeling_falcon/dropout_add + https://github.com/huggingface/transformers/blob/b338a6c3b8eda29610d4d472cad8cd87cbfdaaed/src/transformers/models/falcon/modeling_falcon.py#L248 """ out = F.dropout(x, p=prob, training=training) out.add_(residual) @@ -81,7 +72,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: - # TODO: remove `.clone()` when SynapseAI v1.15 is released + # TODO: remove `.clone()` once the problem is fixed in SynapseAI return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ), FusedRoPE.apply( @@ -274,197 +265,6 @@ def update_sincos_cache(self, seq_len): seq_len, self.query_key_value.weight.device, self.query_key_value.weight.dtype ) - def forward( - self, - hidden_states: torch.Tensor, - alibi: 
Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - token_idx: Optional[torch.Tensor] = None, - reuse_cache: Optional[bool] = False, - cache_idx: int = None, - **kwargs, - ): - """ - Copied from FalconAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py - The only differences are: - - add new args token_idx and position_ids - - replace F.scaled_dot_product_attention with Habana torch's version - - add new args reuse_cache - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, query_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - key_layer = key_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size, -1, query_length, self.head_dim) - - kv_seq_len = key_layer.shape[-2] - if layer_past is not None: - if token_idx is not None: - if reuse_cache: - kv_seq_len = layer_past[0][-2] # layer_past conveys only shapes without kv tensors - else: - kv_seq_len = layer_past[0].shape[-2] - else: - kv_length += layer_past[0].shape[-2] - if alibi is None: - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) - - if layer_past is not None or reuse_cache: - if reuse_cache: - key_layer = self.k_cache(key_layer, -2, token_idx) - value_layer = self.v_cache(value_layer, -2, token_idx) - else: - key_layer = update( - layer_past[0], key_layer, -2, token_idx, self.inp_seq_len - ) # k_layer bs*1, q_len, head_dim - value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) - - if cache_idx is not None and query_length == 1: - key_layer = key_layer[:, :, :cache_idx, :] - value_layer = value_layer[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_layer.shape[-2] - - kv_length = key_layer.shape[-2] - if use_cache: - if reuse_cache: - present = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - present = (key_layer, value_layer) - else: - present = None - - if alibi is None: - if output_attentions: - attention_scores = query_layer @ key_layer.transpose(-1, -2) - attention_scores /= math.sqrt(self.head_dim) - - attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) - attn_output = attention_scores @ value_layer - else: - if FusedSDPA: - if os.getenv("QUANT_CONFIG", ""): - attn_output = self.sdpa( - query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False - ) - else: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask 
in case query_length == 1. - self.is_causal and attention_mask is None and query_length > 1, - ) - else: - # Workaround util scaled_dot_product_attention support broadcast. - if self.training is True and query_layer.shape != key_layer.shape: - key_layer = torch.broadcast_to(key_layer, query_layer.shape) - value_layer = torch.broadcast_to(value_layer, query_layer.shape) - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attention_mask, - 0.0, - # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - # Performance improvement for HPU - if self.training is True and htcore: - htcore.mark_step() - attention_scores = None - - attn_output = attn_output.view(batch_size, -1, query_length, self.head_dim) - attn_output = attn_output.permute(0, 2, 1, 3) - attn_output = attn_output.reshape(batch_size, query_length, -1) - - attn_output = self.dense(attn_output) - - if output_attentions: - return attn_output, present, attention_scores - else: - return attn_output, present - - else: - if self._use_sdpa and not output_attentions and head_mask is None: - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_layer, - key_layer, - value_layer, - attention_mask, - self.attention_dropout.p if self.training else 0.0, - self.is_causal and attention_mask is None and query_length > 1, - ) - else: - attn_output = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.attention_dropout.p if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and query_length > 1, - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) - - attn_output = self.dense(attn_output) - else: - matmul_result = query_layer @ key_layer.transpose(-1, -2) - - # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16 or input_dtype == torch.bfloat16: - attention_scores = attention_scores.to(torch.float32) - - attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) - attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) - # [batch_size, num_heads, q_length, kv_length] - attention_probs = self.attention_dropout(attention_probs) - - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # change view [batch_size, num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) - - # matmul: [batch_size * num_heads, q_length, head_dim] - attn_output = (attention_probs_reshaped @ value_layer).flatten(0, 1) - - # change view [batch_size, q_length, num_heads * head_dim] - attn_output = self._merge_heads(attn_output) - - attn_output = self.dense(attn_output) - - if 
output_attentions: - return attn_output, present, attention_probs - else: - return attn_output, present - def pre_attn_forward( self, hidden_states: torch.Tensor, @@ -485,6 +285,7 @@ def pre_attn_forward( The only differences are: - add new args token_idx and position_ids - replace F.scaled_dot_product_attention with Habana torch's version + - add new arg reuse_cache """ if "padding_mask" in kwargs: warnings.warn( @@ -721,14 +522,10 @@ def forward( warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - if not self.config.new_decoder_architecture: - residual = hidden_states - - attention_layernorm_out = self.input_layernorm(hidden_states) - - # Self attention. - attn_outputs = self.self_attention( - attention_layernorm_out, + residual = hidden_states + hidden_states, present, attn_scores, attention_layernorm_out, mlp_layernorm_out = ( + self.pre_attn( # layernorm + attention before AllReduce + hidden_states, layer_past=layer_past, attention_mask=attention_mask, position_ids=position_ids, @@ -738,11 +535,19 @@ def forward( output_attentions=output_attentions, token_idx=token_idx, reuse_cache=reuse_cache, + cache_idx=cache_idx, **kwargs, ) + ) + + self.self_attention.attention_all_reduce(hidden_states) + hidden_states = self.self_attention.post_attn_forward( + hidden_states + ) - attention_output = attn_outputs[0] + attention_output = hidden_states + if not self.config.new_decoder_architecture: if self.config.parallel_attn: mlp_layernorm_out = attention_layernorm_out else: @@ -751,42 +556,11 @@ def forward( ) mlp_layernorm_out = self.post_attention_layernorm(residual) - outputs = attn_outputs[1:] - else: - residual = hidden_states - hidden_states, present, attn_scores, mlp_layernorm_out = ( - self.pre_attn( # layernorm+attention before AllReduce - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - position_ids=position_ids, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - token_idx=token_idx, - reuse_cache=reuse_cache, - cache_idx=cache_idx, - **kwargs, - ) - ) + outputs = (present, attn_scores) - self.self_attention.attention_all_reduce(hidden_states) - hidden_states = self.self_attention.post_attn_forward( - hidden_states - ) - - attention_output = hidden_states - - outputs = (present, attn_scores) - - # MLP - if not self.config.new_decoder_architecture: - hidden_states = self.mlp(mlp_layernorm_out) - else: - hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) - self.mlp.mlp_all_reduce(hidden_states) - hidden_states = self.mlp.post_mlp_forward(hidden_states) + hidden_states = self.mlp.pre_mlp_forward(mlp_layernorm_out) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.mlp.post_mlp_forward(hidden_states) if self.config.new_decoder_architecture or self.config.parallel_attn: hidden_states += attention_output @@ -852,7 +626,7 @@ def pre_attn( cache_idx=cache_idx, ) - return attn_outputs, present, attn_scores, mlp_layernorm_out + return attn_outputs, present, attn_scores, attention_layernorm_out, mlp_layernorm_out class GaudiFalconModel(FalconModel): @@ -864,6 +638,7 @@ class GaudiFalconModel(FalconModel): - set past_key_values_length=0 when token_idx is used (with static input shape) - add new arg tgt_len to _expand_mask because past_key_values_length is no longer valid with token_idx - use old version of _make_causal_mask to workaround toch.triu that is not supported in Synapse + - add new arg reuse_cache """ def 
allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -926,7 +701,7 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation past_key_values_length = 0 - if past_key_values[0] is not None and token_idx is None: ### non static input + if past_key_values[0] is not None and token_idx is None: if reuse_cache: past_key_values_length = past_key_values[0][0][-2] else: From af025a79ad340b0875e77dff8a4b557666d4fade Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Fri, 8 Mar 2024 21:45:56 +0200 Subject: [PATCH 3/8] resolve issues in finetuning --- examples/text-generation/README.md | 10 ++++---- .../models/falcon/modeling_falcon.py | 24 +++++++------------ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 9ca3e396ab..ed987694e7 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -324,17 +324,15 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: ```bash QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \ ---use_deepspeed --world_size 8 run_generation.py \ +--use_deepspeed --world_size 8 run_lm_eval.py \ +-o acc_falcon180b_bs1_quant.txt \ --model_name_or_path tiiuae/falcon-180B \ --use_hpu_graphs \ --use_kv_cache \ ---limit_hpu_graphs \ ---max_input_tokens 128 \ ---max_new_tokens 128 \ +--trim_logits \ --batch_size 1 \ --bf16 \ ---reuse_cache \ ---trim_logits +--reuse_cache ``` Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards: diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 1db2a61022..ed49a0ca99 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -1,8 +1,8 @@ import contextlib import math +import os import warnings from typing import Optional, Tuple, Union -import os import torch @@ -42,13 +42,10 @@ FalconDecoderLayer, FalconForCausalLM, FalconMLP, - FalconLinear, FalconModel, - FalconRotaryEmbedding, apply_rotary_pos_emb, build_alibi_tensor, ) -from ..modeling_all_models import ScopedLinearAllReduce from transformers.utils import logging from ...modeling_attn_mask_utils import ( @@ -66,7 +63,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: https://github.com/huggingface/transformers/blob/b338a6c3b8eda29610d4d472cad8cd87cbfdaaed/src/transformers/models/falcon/modeling_falcon.py#L248 """ out = F.dropout(x, p=prob, training=training) - out.add_(residual) + if training: + out = residual + out + else: + out.add_(residual) return out @@ -161,6 +161,7 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa if is_causal: assert attn_mask is None + attn_bias = torch.zeros(L, S, dtype=query.dtype) temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) @@ -230,13 +231,6 @@ class GaudiFalconAttention(FalconAttention): def __init__(self, config: FalconConfig): super().__init__(config) - if config.new_decoder_architecture: - qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim - elif config.multi_query: - qkv_out_dim = self.hidden_size + 2 * 
self.head_dim - else: - qkv_out_dim = 3 * self.hidden_size - if os.getenv("QUANT_CONFIG", ""): self.sdpa = ScaledDotProductAttention(config) @@ -285,7 +279,7 @@ def pre_attn_forward( The only differences are: - add new args token_idx and position_ids - replace F.scaled_dot_product_attention with Habana torch's version - - add new arg reuse_cache + - add new arg reuse_cache """ if "padding_mask" in kwargs: warnings.warn( @@ -541,9 +535,7 @@ def forward( ) self.self_attention.attention_all_reduce(hidden_states) - hidden_states = self.self_attention.post_attn_forward( - hidden_states - ) + hidden_states = self.self_attention.post_attn_forward(hidden_states) attention_output = hidden_states From 7b9852a7f3d46467c846dfa1271f8d6402835fc4 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Tue, 12 Mar 2024 01:13:31 +0200 Subject: [PATCH 4/8] enable non reuse cache flow for fp8 --- examples/text-generation/README.md | 1 + .../models/falcon/modeling_falcon.py | 70 ++++++++++--------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index ed987694e7..2a5db4c926 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -322,6 +322,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati ``` Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards: +> Please note that Falcon-180B is a gated model, and users are required to request access to it. Please refer to the instructions provided in the StarCoder example above. ```bash QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \ --use_deepspeed --world_size 8 run_lm_eval.py \ diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index ed49a0ca99..0ce69ce606 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -178,27 +178,6 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa return self.bmm2(attn_weight, value) -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - cur = cur.to(dtype=prev.dtype) - - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - - if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) - - class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() @@ -224,7 +203,23 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) def update(self, prev, cur, dim, idx, inp_seq_len): - return update(prev, cur, dim, idx, inp_seq_len) + orig_cur = cur + cur = cur.to(dtype=prev.dtype) + + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + + if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[-2] == 1, f"Cannot update kv-cache. Unsupported shapes. 
prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) class GaudiFalconAttention(FalconAttention): @@ -310,31 +305,39 @@ def pre_attn_forward( cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) - if layer_past is not None or reuse_cache: + if use_cache: if reuse_cache: key_layer = self.k_cache(key_layer, -2, token_idx) value_layer = self.v_cache(value_layer, -2, token_idx) + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - key_layer = update( + if layer_past is None: + past_key = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + past_value = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + layer_past = (past_key, past_value) + key_layer = self.k_cache.update( layer_past[0], key_layer, -2, token_idx, self.inp_seq_len ) # k_layer bs*1, q_len, head_dim - value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + present = layer_past if cache_idx is not None and query_length == 1: key_layer = key_layer[:, :, :cache_idx, :] value_layer = value_layer[:, :, :cache_idx, :] attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_layer.shape[-2] - - kv_length = key_layer.shape[-2] - if use_cache: - if reuse_cache: - present = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - present = (key_layer, value_layer) else: present = None + kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] + if alibi is None: if output_attentions: attention_scores = query_layer @ key_layer.transpose(-1, -2) @@ -349,7 +352,6 @@ def pre_attn_forward( attn_output = self.sdpa( query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False ) - else: with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( From 8a18736e9dc614865906c4fa38c4abd4f28e05c9 Mon Sep 17 00:00:00 2001 From: Local Lab User Date: Wed, 13 Mar 2024 02:15:18 +0200 Subject: [PATCH 5/8] revert non reuse_cache flow for training due to perf drop --- .../models/falcon/modeling_falcon.py | 105 ++++++++++-------- 1 file changed, 61 insertions(+), 44 deletions(-) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 0ce69ce606..a329ec1ac0 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -178,6 +178,27 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa return self.bmm2(attn_weight, value) +def update(prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + cur = cur.to(dtype=prev.dtype) + + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + + if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. 
prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + prev_cast = prev.to(orig_cur.dtype) + return prev_cast + else: + return torch.cat((prev, cur), dim=dim) + + class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() @@ -203,23 +224,7 @@ def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) def update(self, prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - cur = cur.to(dtype=prev.dtype) - - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - - if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[-2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - return prev - else: - return torch.cat((prev, cur), dim=dim) + return update(prev, cur, dim, idx, inp_seq_len) class GaudiFalconAttention(FalconAttention): @@ -306,37 +311,49 @@ def pre_attn_forward( query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids) if use_cache: - if reuse_cache: - key_layer = self.k_cache(key_layer, -2, token_idx) - value_layer = self.v_cache(value_layer, -2, token_idx) - present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + if self.training: + if layer_past is not None: + key_layer = update(layer_past[0], key_layer, -2, token_idx, self.inp_seq_len) + value_layer = update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + present = (key_layer, value_layer) + else: + present = None + else: - if layer_past is None: - past_key = torch.zeros( - key_layer.shape, - dtype=self.query_key_value.weight.dtype, - device=self.query_key_value.weight.device, - ) - past_value = torch.zeros( - key_layer.shape, - dtype=self.query_key_value.weight.dtype, - device=self.query_key_value.weight.device, - ) - layer_past = (past_key, past_value) - key_layer = self.k_cache.update( - layer_past[0], key_layer, -2, token_idx, self.inp_seq_len - ) # k_layer bs*1, q_len, head_dim - value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) - present = layer_past - - if cache_idx is not None and query_length == 1: - key_layer = key_layer[:, :, :cache_idx, :] - value_layer = value_layer[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] + if reuse_cache: + key_layer = self.k_cache(key_layer, -2, token_idx) + value_layer = self.v_cache(value_layer, -2, token_idx) + present = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if layer_past is None: + past_key = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + past_value = torch.zeros( + key_layer.shape, + dtype=self.query_key_value.weight.dtype, + device=self.query_key_value.weight.device, + ) + layer_past = (past_key, past_value) + key_layer = self.k_cache.update( + layer_past[0], key_layer, -2, token_idx, self.inp_seq_len + ) # k_layer bs*1, q_len, head_dim + value_layer = self.v_cache.update(layer_past[1], value_layer, -2, token_idx, self.inp_seq_len) + present = layer_past + + if cache_idx is not None and query_length == 1: + key_layer = key_layer[:, :, :cache_idx, :] + value_layer = value_layer[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] else: present = None - kv_length = present[0][-2] if reuse_cache else present[0].shape[-2] + if 
self.training and layer_past is None:
+            kv_length = key_layer.shape[-2]
+        else:
+            kv_length = present[0][-2] if reuse_cache else present[0].shape[-2]
 
         if alibi is None:
             if output_attentions:

From 3feaed0faa9f82abbd3d0729c35ac5b7723ee7f7 Mon Sep 17 00:00:00 2001
From: Local Lab User
Date: Thu, 14 Mar 2024 01:32:31 +0200
Subject: [PATCH 6/8] add falcon180B FP8 test

---
 tests/test_text_generation_example.py | 29 ++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index ff6f94d002..e1a40bf97e 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -7,7 +7,7 @@
 
 import pytest
 
-from .test_examples import TIME_PERF_FACTOR
+from test_examples import TIME_PERF_FACTOR
 
 
 if os.environ.get("GAUDI2_CI", "0") == "1":
@@ -26,6 +26,9 @@
             ("mistralai/Mistral-7B-v0.1", 125.26115369093216),
             ("mistralai/Mixtral-8x7B-v0.1", 23.78652574031883),
         ],
+        "fp8": [
+            ("tiiuae/falcon-180B", 52.525947696914784),
+        ],
         "deepspeed": [
             ("bigscience/bloomz", 36.34664210641816),
             ("meta-llama/Llama-2-70b-hf", 61.973950428647164),
@@ -69,6 +72,7 @@ def _test_text_generation(
     deepspeed: bool = False,
     world_size: int = 8,
     torch_compile: bool = False,
+    fp8: bool = False,
 ):
     command = ["python3"]
     path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -103,6 +107,13 @@
     if not deepspeed:
         command.append("--bf16")
 
+    if fp8:
+        command += [
+            "--fp8",
+            "--reuse_cache",
+            "--trim_logits",
+        ]
+
     with TemporaryDirectory() as tmp_dir:
         command.append(f"--output_dir {tmp_dir}")
         print(f"\n\nCommand to test: {' '.join(command)}\n")
@@ -112,6 +123,15 @@
         pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
         command = [x for y in command for x in re.split(pattern, y) if x]
 
+        if fp8:
+            os.environ["QUANT_CONFIG"] = os.path.join(
+                path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json"
+            )
+            subprocess.run(command)
+            os.environ["QUANT_CONFIG"] = os.path.join(
+                path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json"
+            )
+
         proc = subprocess.run(command)
 
         # Ensure the run finished without any issue
@@ -135,6 +155,13 @@ def test_text_generation_bf16(model_name: str, baseline: float, token: str):
     _test_text_generation(model_name, baseline, token)
 
 
+@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["fp8"])
+def test_text_generation_fp8(model_name: str, baseline: float, token: str):
+    deepspeed = True if "falcon-180B" in model_name else False
+    world_size = 8 if "falcon-180B" in model_name else None
+    _test_text_generation(model_name, baseline, token, deepspeed=deepspeed, world_size=world_size, fp8=True)
+
+
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["deepspeed"])
 def test_text_generation_deepspeed(model_name: str, baseline: float, token: str):
     world_size = 2 if "opt-66b" in model_name else 8

From 1b90b33ff50932679f91d26b2e90338c71694976 Mon Sep 17 00:00:00 2001
From: Local Lab User
Date: Thu, 14 Mar 2024 03:06:10 +0200
Subject: [PATCH 7/8] fix error

---
 tests/test_text_generation_example.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index e1a40bf97e..00602dbd0e 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -7,7 +7,7 @@
 
 import pytest
 
-from test_examples import TIME_PERF_FACTOR
+from .test_examples import TIME_PERF_FACTOR
 
 
 if os.environ.get("GAUDI2_CI", "0") == "1":
@@ -27,7 +27,7 @@
             ("mistralai/Mixtral-8x7B-v0.1", 23.78652574031883),
         ],
         "fp8": [
-            ("tiiuae/falcon-180B", 52.525947696914784),
+            ("tiiuae/falcon-180B", 47.67900945905787),
         ],
         "deepspeed": [
             ("bigscience/bloomz", 36.34664210641816),

From 5a2128b52ec6108c525ab266a542504d50cd7b69 Mon Sep 17 00:00:00 2001
From: Sun Choi
Date: Sat, 16 Mar 2024 07:15:50 +0000
Subject: [PATCH 8/8] fix run_lm_eval.py to save --reuse_cache

---
 examples/text-generation/run_lm_eval.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py
index 4f90306354..8d61118890 100644
--- a/examples/text-generation/run_lm_eval.py
+++ b/examples/text-generation/run_lm_eval.py
@@ -75,10 +75,15 @@ def __init__(self, tokenizer, model, args, options):
         self.options = options
         self._device = args.device
         self.model_inputs = {"use_cache": self.options.use_cache}
-        if self.model.config.model_type == "llama":
+        if self.model.config.model_type in ("llama", "falcon"):
             self.model_inputs.update(
                 {
                     "reuse_cache": self.options.reuse_cache,
+                }
+            )
+        if self.model.config.model_type == "llama":
+            self.model_inputs.update(
+                {
                     "attn_softmax_bf16": self.options.attn_softmax_bf16,
                 }
             )
@@ -131,11 +136,7 @@ def _model_call(self, inps):
         if self.options.static_shapes:
             bucket_length = self.find_bucket(seq_length)
             if self.options.use_cache and self.options.reuse_cache:
-                self.model.allocate_kv_cache(
-                    bs,
-                    bucket_length + 1,
-                    bucket_length
-                )
+                self.model.allocate_kv_cache(bs, bucket_length + 1, bucket_length)
             padding_length = bucket_length - seq_length
             inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
         logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu()
@@ -177,6 +178,7 @@ def main():
 
         habana_quantization_toolkit.finish_measurements(model)
     if args.const_serialization_path and os.path.isdir(args.const_serialization_path):
         import shutil
+
         shutil.rmtree(args.const_serialization_path)
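
A note on the caching scheme that ties these patches together: Falcon decoding on Gaudi keeps a preallocated, fixed-shape key/value cache and writes into it in place — the whole prompt during prefill, then one token per step at position token_idx - 1 — rather than concatenating past and current states every step. The sketch below illustrates that update pattern in isolation; the shapes and the cache_update name are hypothetical, and the logic is a simplified form of the update()/KVCache.update helpers and GaudiFalconAttention.pre_attn_forward in the patches above, not the exact optimum-habana code.

import torch


def cache_update(prev, cur, dim, idx, inp_seq_len):
    # Write `cur` into the preallocated cache `prev` without changing its shape.
    if prev.shape == cur.shape:
        prev.copy_(cur)                          # cache is exactly filled
        return cur
    if cur.shape[-2] > 1:
        prev[:, :, :inp_seq_len, :].copy_(cur)   # prefill: prompt goes at the front
        return cur
    if idx is not None:
        prev.index_copy_(dim, idx - 1, cur)      # decode: one token at idx - 1 (idx is 1-based)
        return prev
    return torch.cat((prev, cur), dim=dim)       # fallback: dynamic growth


# batch=1, kv_heads=4, max_seq_len=8, head_dim=16, prompt of 5 tokens
cache = torch.zeros(1, 4, 8, 16)
cache_update(cache, torch.randn(1, 4, 5, 16), dim=-2, idx=None, inp_seq_len=5)        # prefill
token_idx = torch.tensor([6])                                                         # 1-based write position
cache_update(cache, torch.randn(1, 4, 1, 16), dim=-2, idx=token_idx, inp_seq_len=5)   # decode step

Because the cache tensor never changes shape, every decode step runs with identical tensor shapes — the property HPU graph replay depends on — which is also why run_lm_eval.py preallocates with allocate_kv_cache(bs, bucket_length + 1, bucket_length) before the forward pass when --reuse_cache is set.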