From 6419b041eceefdd4db7eb95094875edb4f99481c Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Sun, 31 Dec 2023 20:24:31 -0700 Subject: [PATCH 1/9] initial implementation of flash attention for gptj --- src/transformers/models/gptj/modeling_gptj.py | 308 ++++++++++++++++-- 1 file changed, 287 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index e3034eecaf04..c3d5d6df54d3 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -19,6 +19,7 @@ import torch import torch.fx +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -35,6 +36,8 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, is_torch_fx_proxy, logging, ) @@ -42,6 +45,11 @@ from .configuration_gptj import GPTJConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj" @@ -55,6 +63,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float() @@ -82,7 +103,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten class GPTJAttention(nn.Module): def __init__(self, config): super().__init__() - + self.config = config max_positions = config.max_position_embeddings self.register_buffer( "bias", @@ -96,6 +117,8 @@ def __init__(self, config): self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + self.embed_dim = config.hidden_size self.num_attention_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_attention_heads @@ -269,6 +292,241 @@ def forward( return outputs # a, present, (attentions) +class GPTJFlashAttention2(GPTJAttention): + """ + GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.FloatTensor, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + + if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): + # The logic to conditionally copy to GPU could not be traced, so we do this + # every time in the torch.fx case + embed_positions = get_embed_positions(self.embed_positions, position_ids) + else: + embed_positions = self._get_embed_positions(position_ids) + + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=2) + + if use_cache is True: + # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. + # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128 + present = (key.to(hidden_states.dtype), value) + else: + present = None + + # The Falsh attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we need to keep the original shape for query and key, and reshape value + # to have the correct shape. + value = value.permute(0, 2, 1, 3).contiguous() + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attention_dropout = self.config.attn_pdrop if self.training else 0.0 # attn_pdrop in gptj + + query_length = query.shape[1] + + # Compute attention + attn_weights = self._flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attention_dropout, + ) + + # Reshape outputs + attn_output = attn_weights.reshape( + attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3] + ) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->num_attention_heads + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class GPTJMLP(nn.Module): def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim super().__init__() @@ -293,7 +551,11 @@ def __init__(self, config): super().__init__() inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config) + self.attn = ( + GPTJFlashAttention2(config) + if config._attn_implementation == "flash_attention_2" + else GPTJAttention(config) + ) self.mlp = GPTJMLP(inner_dim, config) def forward( @@ -343,6 +605,7 @@ class GPTJPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -496,6 +759,8 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): warnings.warn( @@ -600,25 +865,26 @@ def forward( position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + if not self._use_flash_attention_2: + # Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
+ attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head From 3a9e31f9b5e541f6bab03fac8879418770c3a7f3 Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Mon, 1 Jan 2024 19:26:58 -0700 Subject: [PATCH 2/9] modify flash attention and overwrite test_flash_attn_2_generate_padding_right --- src/transformers/models/gptj/modeling_gptj.py | 13 ++++- tests/models/gptj/test_modeling_gptj.py | 53 ++++++++++++++++++- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index c3d5d6df54d3..2c0ba30a37b8 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -355,11 +355,18 @@ def forward( key = apply_rotary_pos_emb(key, sin, cos) query = apply_rotary_pos_emb(query, sin, cos) + # tanspose to have the desired shape + # before transpose: batch_size x seq_length x num_attention_heads x head_dim + # after transpose: batch_size x num_attention_heads x seq_length x head_dim + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + # value: batch_size x num_attention_heads x seq_length x head_dim + if layer_past is not None: past_key = layer_past[0] past_value = layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=2) + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) if use_cache is True: # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation. @@ -372,6 +379,8 @@ def forward( # batch_size x seq_length x head_dim x hidden_dim # therefore we need to keep the original shape for query and key, and reshape value # to have the correct shape. 
+ key = key.permute(0, 2, 1, 3).contiguous() + query = query.permute(0, 2, 1, 3).contiguous() value = value.permute(0, 2, 1, 3).contiguous() # In PEFT, usually we cast the layer norms in float32 for training stability reasons diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py index 42ded9c81ae0..f9699f3dd600 100644 --- a/tests/models/gptj/test_modeling_gptj.py +++ b/tests/models/gptj/test_modeling_gptj.py @@ -17,8 +17,17 @@ import datetime import unittest +import pytest + from transformers import GPTJConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, tooslow, torch_device +from transformers.testing_utils import ( + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + tooslow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -518,6 +527,48 @@ def test_model_from_pretrained(self): model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16) self.assertIsNotNone(model) + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = GPTJForCausalLM.from_pretrained( + "EleutherAI/gpt-j-6b", + load_in_4bit=True, + device_map={"": 0}, + revision="float16", + torch_dtype=torch.float16, + ) + + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = tokenizer.batch_decode(output_native) + + model = GPTJForCausalLM.from_pretrained( + "EleutherAI/gpt-j-6b", + load_in_4bit=True, + device_map={"": 0}, + attn_implementation="flash_attention_2", + revision="float16", + torch_dtype=torch.float16, + ) + + output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_2 = tokenizer.batch_decode(output_fa_2) + + self.assertListEqual(output_native, output_fa_2) + @require_torch class GPTJModelLanguageGenerationTest(unittest.TestCase): From cb656559a013f2ce526297811f11122dcddeff4f Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:14:57 -0700 Subject: [PATCH 3/9] update flash attention support list --- docs/source/en/perf_infer_gpu_one.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 5cc9cd208d8a..fff57c083fb7 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -43,6 +43,7 @@ FlashAttention-2 is currently supported for the following architectures: * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) +* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * 
[Llava](https://huggingface.co/docs/transformers/model_doc/llava) From e47ef1330ea20fd8b7cefc90bb9c98bac4b5a867 Mon Sep 17 00:00:00 2001 From: bytebarde Date: Wed, 3 Jan 2024 22:07:09 -0700 Subject: [PATCH 4/9] remove the copy line in the `CodeGenBlock` --- src/transformers/models/codegen/modeling_codegen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 6fc054254a48..1d3569feffa3 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -264,7 +264,6 @@ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTens return hidden_states -# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen class CodeGenBlock(nn.Module): def __init__(self, config): super().__init__() From 0c31cb3fd8a1d9d08e218550c5f93ed128702c9c Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Sat, 27 Jan 2024 21:54:47 -0700 Subject: [PATCH 5/9] address copy mechanism --- src/transformers/models/codegen/modeling_codegen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 1d3569feffa3..b73e5718ef5f 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -264,7 +264,9 @@ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTens return hidden_states +# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen class CodeGenBlock(nn.Module): + # Ignore copy def __init__(self, config): super().__init__() inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd From def626ef1699b58e78507161fe32433290454e59 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 30 Jan 2024 02:54:27 +0100 Subject: [PATCH 6/9] Update src/transformers/models/gptj/modeling_gptj.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/gptj/modeling_gptj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 2c0ba30a37b8..eb912be44f71 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -375,7 +375,7 @@ def forward( else: present = None - # The Falsh attention requires the input to have the shape + # The Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we need to keep the original shape for query and key, and reshape value # to have the correct shape. 
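Between PATCH 2's permute hunks and the comment fixed just above, the tensor-layout contract is the subtlest part of the series: the rotary and KV-cache code works in `(batch, num_heads, seq_len, head_dim)`, while `flash_attn_func` consumes `(batch, seq_len, num_heads, head_dim)`. Below is a minimal, self-contained sketch of that round trip using plain `torch`; the sizes, the `past_key` tensor, and the commented-out `flash_attn_func` call are illustrative assumptions, not code from the PR.

```python
# Layout sketch (illustrative only): GPT-J's cache stores (batch, num_heads, seq, head_dim),
# while flash_attn_func expects (batch, seq, num_heads, head_dim), hence the permutes
# around the cache concatenation in PATCH 2.
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 16, 256

# _split_heads(..., rotary=True) leaves query/key as (batch, seq, num_heads, head_dim);
# the rotary=False split already returns value as (batch, num_heads, seq, head_dim).
query = torch.randn(batch, seq_len, num_heads, head_dim)
key = torch.randn(batch, seq_len, num_heads, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Move query/key to (batch, num_heads, seq, head_dim) so past keys concatenate on dim=-2,
# matching the cache layout of the eager attention class.
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
past_key = torch.randn(batch, num_heads, 3, head_dim)  # hypothetical cached prefix of length 3
key = torch.cat((past_key, key), dim=-2)
# (the value cache concat is omitted here for brevity)

# Back to (batch, kv_len, num_heads, head_dim) before calling flash attention.
key = key.permute(0, 2, 1, 3).contiguous()
query = query.permute(0, 2, 1, 3).contiguous()
value = value.permute(0, 2, 1, 3).contiguous()
assert key.shape == (batch, seq_len + 3, num_heads, head_dim)
# On a GPU with flash-attn installed, the call would then look like:
# attn_output = flash_attn_func(query, key, value, dropout_p=0.0, causal=True)
```

Transposing only around the cache concatenation (rather than changing `_split_heads`) keeps the flash path's `past_key_values` in the same layout as the eager implementation's, so caches remain interchangeable between the two attention classes.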
From af0752eae3095ca0c1ff0b30db4450ddf97a01f2 Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Thu, 8 Feb 2024 23:40:03 -0700 Subject: [PATCH 7/9] Add GPTJ attention classes --- src/transformers/models/gptj/modeling_gptj.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 910aba8a34ee..0ea7cf663ab7 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -536,6 +536,12 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +GPTJ_ATTENTION_CLASSES = { + "eager": GPTJAttention, + "flash_attention_2": GPTJFlashAttention2, +} + + class GPTJMLP(nn.Module): def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim super().__init__() @@ -560,11 +566,7 @@ def __init__(self, config): super().__init__() inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = ( - GPTJFlashAttention2(config) - if config._attn_implementation == "flash_attention_2" - else GPTJAttention(config) - ) + self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config) self.mlp = GPTJMLP(inner_dim, config) def forward( From cb265c7390e865ddb4b322a09d83a6768d0644e6 Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Fri, 8 Mar 2024 07:19:12 -0700 Subject: [PATCH 8/9] add expected outputs in the gptj test --- tests/models/gptj/test_modeling_gptj.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py index f9699f3dd600..fd88b85a13e4 100644 --- a/tests/models/gptj/test_modeling_gptj.py +++ b/tests/models/gptj/test_modeling_gptj.py @@ -19,8 +19,9 @@ import pytest -from transformers import GPTJConfig, is_torch_available +from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available from transformers.testing_utils import ( + require_bitsandbytes, require_flash_attn, require_torch, require_torch_gpu, @@ -529,45 +530,41 @@ def test_model_from_pretrained(self): @require_flash_attn @require_torch_gpu + @require_bitsandbytes @pytest.mark.flash_attn_test @slow def test_flash_attn_2_generate_padding_right(self): """ Overwritting the common test as the test is flaky on tiny models """ - model = GPTJForCausalLM.from_pretrained( - "EleutherAI/gpt-j-6b", - load_in_4bit=True, - device_map={"": 0}, - revision="float16", - torch_dtype=torch.float16, - ) - tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") texts = ["hi", "Hello this is a very long sentence"] + expected_outputs = [ + "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. 
I have a question about the", + "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it", + ] tokenizer.padding_side = "right" tokenizer.pad_token = tokenizer.eos_token inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_native = tokenizer.batch_decode(output_native) + quantization_config = BitsAndBytesConfig(load_in_4bit=True) model = GPTJForCausalLM.from_pretrained( "EleutherAI/gpt-j-6b", - load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2", revision="float16", torch_dtype=torch.float16, + quantization_config=quantization_config, ) output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_fa_2 = tokenizer.batch_decode(output_fa_2) - self.assertListEqual(output_native, output_fa_2) + self.assertListEqual(expected_outputs, output_fa_2) @require_torch From 2b489b065e1ba8ededf5dd1704d646e1e596ec99 Mon Sep 17 00:00:00 2001 From: bytebarde <154845754+bytebarde@users.noreply.github.com> Date: Tue, 12 Mar 2024 12:16:35 -0600 Subject: [PATCH 9/9] Ensure repo consistency with 'make fix-copies' --- src/transformers/models/gptj/modeling_gptj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 0ea7cf663ab7..c495d281db5d 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -68,7 +68,7 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens,
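With the expected outputs pinned in the overridden test and `make fix-copies` applied, the series is complete. For reference, a minimal usage sketch of the new path, mirroring the checkpoint, revision, and generation settings used in the test; a CUDA GPU with `flash-attn` >= 2.1 installed is assumed, and the 4-bit quantization from the test is omitted:

```python
# Minimal usage sketch: load GPT-J with the flash-attention-2 implementation added by
# this series and generate greedily, as the overridden test does.
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # right padding is handled by the flash path via _upad_input

model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6b",
    revision="float16",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map={"": 0},
)

inputs = tokenizer(["Hello this is a very long sentence"], return_tensors="pt", padding=True).to(0)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(out))
```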