
Commit

Merge pull request #15 from jacklanda/fix-use-cache-error
[Fix] Fix the "q, k, and v states must have the same dtype" error when using the flash attention forward pass.
gitsailor5 authored May 30, 2024
2 parents eb0586d + 21cd706 commit c73a047
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mergoo/models/modeling_llama.py
@@ -541,6 +541,26 @@ def _flash_attention_forward(
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            # check if any state of q, k, and v has different dtype
            if not (query_states.dtype == key_states.dtype == value_states.dtype):
                if torch.is_autocast_enabled():
                    target_dtype = torch.get_autocast_gpu_dtype()
                # Handle the case where the model is quantized
                elif hasattr(self.config, "_pre_quantization_dtype"):
                    target_dtype = self.config._pre_quantization_dtype
                else:
                    target_dtype = self.q_proj.weight.dtype

                logger.warning_once(
                    f"The input hidden states seems to be silently casted in float32, this might be related to"
                    f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                    f" {target_dtype}."
                )

                query_states = query_states.to(target_dtype)
                key_states = key_states.to(target_dtype)
                value_states = value_states.to(target_dtype)

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
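For context, here is a minimal standalone sketch of the same cast-back logic (assuming only PyTorch; the unify_qkv_dtype helper and the tensor shapes below are illustrative and not part of the commit):

import torch

def unify_qkv_dtype(query_states, key_states, value_states, target_dtype):
    # Cast q, k, and v to a single dtype when they disagree, mirroring the
    # check this commit adds before flash_attn_varlen_func is called.
    if not (query_states.dtype == key_states.dtype == value_states.dtype):
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states

# Example: a float32 q (e.g. left upcast by a float32 layer norm) alongside fp16 k and v.
q = torch.randn(2, 8, 16, dtype=torch.float32)
k = torch.randn(2, 8, 16, dtype=torch.float16)
v = torch.randn(2, 8, 16, dtype=torch.float16)
q, k, v = unify_qkv_dtype(q, k, v, torch.float16)
assert q.dtype == k.dtype == v.dtype == torch.float16

flash_attn_varlen_func rejects mixed-dtype q, k, and v inputs, so performing the cast once before the call avoids the runtime error named in the commit title.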
