From 551de28102ce3e3ef3128f2d074bfb7e0de653fd Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Thu, 24 Oct 2024 08:02:23 -0400
Subject: [PATCH] Even better V-Llama accuracy (#881)

* Even better vllama

* Repeat
---
 .../src/vision_models/mllama/text.rs          | 85 +++++++++----------
 1 file changed, 42 insertions(+), 43 deletions(-)

diff --git a/mistralrs-core/src/vision_models/mllama/text.rs b/mistralrs-core/src/vision_models/mllama/text.rs
index beda900ed..e1ff7e171 100644
--- a/mistralrs-core/src/vision_models/mllama/text.rs
+++ b/mistralrs-core/src/vision_models/mllama/text.rs
@@ -9,7 +9,7 @@ use mistralrs_quant::{linear_no_bias, QuantMethod, QuantMethodConfig, UnquantLin
 use crate::{
     attention::SdpaParams,
     device_map::DeviceMapper,
-    layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, Sdpa},
+    layers::{CausalMasker, Llama3RotaryEmbedding, Sdpa},
     layers_masker::PastKvLenCache,
     paged_attention::{AttentionImplementation, ModelConfigMetadata},
     pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata},
@@ -205,9 +205,20 @@ impl MLlamaTextSelfAttention {
         (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
 
         let mut attn_output = Sdpa
-            .run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?
+            .run_attention(
+                &q.contiguous()?.to_dtype(DType::F32)?,
+                &k.contiguous()?.to_dtype(DType::F32)?,
+                &v.contiguous()?.to_dtype(DType::F32)?,
+                attention_mask
+                    .map(|m| m.to_dtype(DType::F32).unwrap())
+                    .as_ref(),
+                None,
+                &self.sdpa_params,
+            )?
             .transpose(1, 2)?
-            .reshape((bs, q_len, ()))?;
+            .contiguous()?
+            .reshape((bs, q_len, ()))?
+            .to_dtype(q.dtype())?;
 
         if let Some(t) = self.q_proj.quantized_act_type() {
             attn_output = attn_output.to_dtype(t)?;
@@ -300,6 +311,7 @@ struct MLlamaTextCrossAttention {
     num_heads: usize,
     num_kv_heads: usize,
     head_dim: usize,
+    sdpa_params: SdpaParams,
 }
 
 impl MLlamaTextCrossAttention {
@@ -347,6 +359,13 @@ impl MLlamaTextCrossAttention {
             num_heads: cfg.num_attention_heads,
             num_kv_heads: cfg.num_key_value_heads,
             head_dim: cfg.head_dim(),
+            sdpa_params: SdpaParams {
+                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
+                use_flash_attn: false,
+                softcap: None,
+                softmax_scale: 1.0 / (cfg.head_dim() as f32).sqrt(),
+                sliding_window: None,
+            },
         })
     }
 
@@ -396,9 +415,6 @@ impl MLlamaTextCrossAttention {
                 .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                 .transpose(1, 2)?;
 
-            k = repeat_kv(k.clone(), self.num_heads / self.num_kv_heads)?.contiguous()?;
-            v = repeat_kv(v.clone(), self.num_heads / self.num_kv_heads)?.contiguous()?;
-
             (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
             (k, v)
         } else if let Some((k_cache, v_cache)) = kv_cache {
@@ -407,43 +423,26 @@ impl MLlamaTextCrossAttention {
             candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
         };
 
-        let mut attn_output = {
-            let att = match attention_mask {
-                Some(m) => {
-                    let mut out = m.to_dtype(DType::F32)?.repeat((1, self.num_heads, 1, 1))?;
-                    q.contiguous()?
-                        .to_dtype(DType::F32)?
-                        .matmul_with_alpha_beta(
-                            &k.t()?.contiguous()?.to_dtype(DType::F32)?,
-                            &mut out,
-                            Some(1. / (self.head_dim as f64).sqrt()),
-                        )?;
-                    out.to_dtype(q.dtype())?
-                }
-                None => MatMul.matmul_affine_div(
-                    &q.contiguous()?,
-                    &k.t()?.contiguous()?,
-                    (self.head_dim as f64).sqrt(),
-                )?,
-            };
-            // let att = MatMul.matmul_affine_div(
-            //     &q.contiguous()?,
-            //     &k.t()?.contiguous()?,
-            //     (self.head_dim as f64).sqrt(),
-            // )?;
-
-            // let att = match attention_mask {
-            //     Some(m) => att.broadcast_add(m)?,
-            //     None => att,
-            // };
-            let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
-                .to_dtype(q.dtype())?;
-            // Convert to contiguous as matmul doesn't support strided vs for now.
-            MatMul
-                .matmul(&att, &v.contiguous()?)?
-                .transpose(1, 2)?
-                .reshape((bs, q_len, ()))?
-        };
+        let mut attn_output = Sdpa
+            .run_attention(
+                &q.contiguous()?.to_dtype(DType::F32)?,
+                &k.contiguous()?.to_dtype(DType::F32)?,
+                &v.contiguous()?.to_dtype(DType::F32)?,
+                attention_mask
+                    .map(|m| {
+                        m.to_dtype(DType::F32)
+                            .unwrap()
+                            .repeat((1, self.num_heads, 1, 1))
+                            .unwrap()
+                    })
+                    .as_ref(),
+                None,
+                &self.sdpa_params,
+            )?
+            .transpose(1, 2)?
+            .contiguous()?
+            .reshape((bs, q_len, ()))?
+            .to_dtype(q.dtype())?;
 
         if let Some(t) = self.q_proj.quantized_act_type() {
             attn_output = attn_output.to_dtype(t)?;
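
Reviewer note: the patch replaces the hand-rolled matmul/softmax/matmul in cross-attention
with the shared Sdpa helper already used by self-attention, and runs both paths in F32
end to end (Q, K, V, and the mask are upcast before run_attention; the output is cast
back to the input dtype). A minimal sketch of that numerical idea, assuming candle's
Tensor API; the sdpa_f32 helper name is illustrative and not part of this patch, and
mask handling is omitted for brevity:

    use candle_core::{DType, Result, Tensor};

    /// Hypothetical helper: scaled dot-product attention computed in F32.
    /// Upcasting before the softmax avoids the f16/bf16 accuracy loss this
    /// patch targets, at the cost of extra casts and memory traffic.
    fn sdpa_f32(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
        let in_dtype = q.dtype();
        let q = q.contiguous()?.to_dtype(DType::F32)?;
        let k = k.contiguous()?.to_dtype(DType::F32)?;
        let v = v.contiguous()?.to_dtype(DType::F32)?;
        // QK^T scaled by 1/sqrt(head_dim); the softmax reduction stays in F32.
        let att = (q.matmul(&k.t()?.contiguous()?)? * scale)?;
        let att = candle_nn::ops::softmax_last_dim(&att)?;
        // Attention-weighted values, cast back to the caller's dtype.
        att.matmul(&v)?.to_dtype(in_dtype)
    }

The stored SdpaParams (with n_kv_groups = num_attention_heads / num_key_value_heads)
also lets run_attention handle grouped-query K/V broadcasting internally, which is why
the explicit repeat_kv calls are deleted.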