From aba6b66caf708a33eaf88cc0fc66d06fb0f08a08 Mon Sep 17 00:00:00 2001
From: Yang
Date: Wed, 29 May 2024 00:52:37 +0800
Subject: [PATCH 1/2] 🚑 Fix: cast the dtype of q, k, and v back to the same target dtype.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mergoo/models/modeling_llama.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mergoo/models/modeling_llama.py b/mergoo/models/modeling_llama.py
index 8e35045..7c58c34 100644
--- a/mergoo/models/modeling_llama.py
+++ b/mergoo/models/modeling_llama.py
@@ -541,6 +541,26 @@ def _flash_attention_forward(
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
 
+            input_dtype = query_states.dtype
+            if input_dtype == torch.float32:
+                if torch.is_autocast_enabled():
+                    target_dtype = torch.get_autocast_gpu_dtype()
+                # Handle the case where the model is quantized
+                elif hasattr(self.config, "_pre_quantization_dtype"):
+                    target_dtype = self.config._pre_quantization_dtype
+                else:
+                    target_dtype = self.q_proj.weight.dtype
+
+                logger.warning_once(
+                    f"The input hidden states seem to have been silently cast to float32; this might be because"
+                    f" the embedding or layer norm layers were upcast to float32. We will cast the inputs back to"
+                    f" {target_dtype}."
+                )
+
+                query_states = query_states.to(target_dtype)
+                key_states = key_states.to(target_dtype)
+                value_states = value_states.to(target_dtype)
+
             attn_output_unpad = flash_attn_varlen_func(
                 query_states,
                 key_states,

From 21cd706bf9e1d0cfa6deaab0fb343a2b43183a25 Mon Sep 17 00:00:00 2001
From: Yang
Date: Wed, 29 May 2024 11:05:30 +0800
Subject: [PATCH 2/2] ✨ Add dtype checking for the q, k, and v states before casting them back.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mergoo/models/modeling_llama.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mergoo/models/modeling_llama.py b/mergoo/models/modeling_llama.py
index 7c58c34..9f32eff 100644
--- a/mergoo/models/modeling_llama.py
+++ b/mergoo/models/modeling_llama.py
@@ -541,8 +541,8 @@ def _flash_attention_forward(
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
 
-            input_dtype = query_states.dtype
-            if input_dtype == torch.float32:
+            # Check whether any of the q, k, and v states has a mismatched dtype.
+            if not (query_states.dtype == key_states.dtype == value_states.dtype):
                 if torch.is_autocast_enabled():
                     target_dtype = torch.get_autocast_gpu_dtype()
                 # Handle the case where the model is quantized
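
Reviewer note (not part of the patches): below is a minimal, standalone sketch of the
check-then-cast behaviour the two commits converge on. The helper name
cast_states_to_common_dtype, the toy tensor shapes, and passing target_dtype as an
explicit argument are illustrative assumptions only; the actual patch derives the target
dtype from the autocast state, the quantization config, or the q_proj weight, as shown above.

import torch

def cast_states_to_common_dtype(query_states, key_states, value_states, target_dtype):
    # If the q/k/v dtypes disagree (e.g. a layer norm silently upcast one of them to
    # float32 under autocast), cast all three back to target_dtype before calling
    # flash attention, whose kernels only accept fp16/bf16 inputs.
    if not (query_states.dtype == key_states.dtype == value_states.dtype):
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states

# Toy usage: key_states was upcast to float32 while q and v stayed in float16.
q = torch.randn(2, 8, 16).to(torch.float16)
k = torch.randn(2, 8, 16)                     # float32: simulates the silent upcast
v = torch.randn(2, 8, 16).to(torch.float16)
q, k, v = cast_states_to_common_dtype(q, k, v, torch.float16)
assert q.dtype == k.dtype == v.dtype == torch.float16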