Skip to content

Commit

Permalink
Even better V-Llama accuracy (#881)
Browse files Browse the repository at this point in the history
* Even better vllama

* Repeat
  • Loading branch information
EricLBuehler authored Oct 24, 2024
1 parent 7a67dda commit 6aa2a51
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions mistralrs-core/src/vision_models/mllama/text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use mistralrs_quant::{linear_no_bias, QuantMethod, QuantMethodConfig, UnquantLin
use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, Sdpa},
layers::{CausalMasker, Llama3RotaryEmbedding, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata},
Expand Down Expand Up @@ -205,9 +205,20 @@ impl MLlamaTextSelfAttention {
(k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

let mut attn_output = Sdpa
.run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?
.run_attention(
&q.contiguous()?.to_dtype(DType::F32)?,
&k.contiguous()?.to_dtype(DType::F32)?,
&v.contiguous()?.to_dtype(DType::F32)?,
attention_mask
.map(|m| m.to_dtype(DType::F32).unwrap())
.as_ref(),
None,
&self.sdpa_params,
)?
.transpose(1, 2)?
.reshape((bs, q_len, ()))?;
.contiguous()?
.reshape((bs, q_len, ()))?
.to_dtype(q.dtype())?;

if let Some(t) = self.q_proj.quantized_act_type() {
attn_output = attn_output.to_dtype(t)?;
Expand Down Expand Up @@ -300,6 +311,7 @@ struct MLlamaTextCrossAttention {
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
sdpa_params: SdpaParams,
}

impl MLlamaTextCrossAttention {
Expand Down Expand Up @@ -347,6 +359,13 @@ impl MLlamaTextCrossAttention {
num_heads: cfg.num_attention_heads,
num_kv_heads: cfg.num_key_value_heads,
head_dim: cfg.head_dim(),
sdpa_params: SdpaParams {
n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
use_flash_attn: false,
softcap: None,
softmax_scale: 1.0 / (cfg.head_dim() as f32).sqrt(),
sliding_window: None,
},
})
}

Expand Down Expand Up @@ -396,9 +415,6 @@ impl MLlamaTextCrossAttention {
.reshape((bs, (), self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;

k = repeat_kv(k.clone(), self.num_heads / self.num_kv_heads)?.contiguous()?;
v = repeat_kv(v.clone(), self.num_heads / self.num_kv_heads)?.contiguous()?;

(k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
(k, v)
} else if let Some((k_cache, v_cache)) = kv_cache {
Expand All @@ -407,43 +423,26 @@ impl MLlamaTextCrossAttention {
candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
};

let mut attn_output = {
let att = match attention_mask {
Some(m) => {
let mut out = m.to_dtype(DType::F32)?.repeat((1, self.num_heads, 1, 1))?;
q.contiguous()?
.to_dtype(DType::F32)?
.matmul_with_alpha_beta(
&k.t()?.contiguous()?.to_dtype(DType::F32)?,
&mut out,
Some(1. / (self.head_dim as f64).sqrt()),
)?;
out.to_dtype(q.dtype())?
}
None => MatMul.matmul_affine_div(
&q.contiguous()?,
&k.t()?.contiguous()?,
(self.head_dim as f64).sqrt(),
)?,
};
// let att = MatMul.matmul_affine_div(
// &q.contiguous()?,
// &k.t()?.contiguous()?,
// (self.head_dim as f64).sqrt(),
// )?;

// let att = match attention_mask {
// Some(m) => att.broadcast_add(m)?,
// None => att,
// };
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
// Convert to contiguous as matmul doesn't support strided vs for now.
MatMul
.matmul(&att, &v.contiguous()?)?
.transpose(1, 2)?
.reshape((bs, q_len, ()))?
};
let mut attn_output = Sdpa
.run_attention(
&q.contiguous()?.to_dtype(DType::F32)?,
&k.contiguous()?.to_dtype(DType::F32)?,
&v.contiguous()?.to_dtype(DType::F32)?,
attention_mask
.map(|m| {
m.to_dtype(DType::F32)
.unwrap()
.repeat((1, self.num_heads, 1, 1))
.unwrap()
})
.as_ref(),
None,
&self.sdpa_params,
)?
.transpose(1, 2)?
.contiguous()?
.reshape((bs, q_len, ()))?
.to_dtype(q.dtype())?;

if let Some(t) = self.q_proj.quantized_act_type() {
attn_output = attn_output.to_dtype(t)?;
Expand Down

0 comments on commit 6aa2a51

Please sign in to comment.