Add a faster and nicer fix
EricLBuehler authored and Nicolas Aveline committed Nov 7, 2024
1 parent b5727a1 commit f7cc189
Showing 5 changed files with 186 additions and 37 deletions.
2 changes: 1 addition & 1 deletion mistralrs-core/src/vision_models/mllama/mod.rs
@@ -196,7 +196,7 @@ pub(crate) struct MLlamaSpecificArgs {

impl VisionModel for MLlamaModel {
fn cache(&self) -> &Cache {
&self.language_model.self_attn_cache
&self.language_model.cache
}
fn config(&self) -> &ModelConfigMetadata {
&self.language_model.cfg
111 changes: 80 additions & 31 deletions mistralrs-core/src/vision_models/mllama/text.rs
@@ -2,22 +2,52 @@

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, IndexOp, Result, Tensor};
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Activation, Embedding, Module, VarBuilder};
use mistralrs_quant::{linear_no_bias, QuantMethod, QuantMethodConfig, UnquantLinear};

use crate::{
attention::SdpaParams,
device_map::DeviceMapper,
layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, RmsNorm, Sdpa},
layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata},
utils::unvarbuilder::UnVarBuilder,
utils::unvarbuilder::{ToTensors, UnVarBuilder},
};

use super::config::MLlamaTextConfig;

struct MLlamaRmsNorm {
w: Tensor,
eps: f64,
}

impl MLlamaRmsNorm {
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
Ok(Self {
w: vb.get((size,), "weight")?,
eps,
})
}
}

impl Module for MLlamaRmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let initial_type = xs.dtype();
let mut xs = xs.to_dtype(DType::F32)?;
let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
}
}

impl ToTensors for MLlamaRmsNorm {
fn to_tensors(&self) -> HashMap<String, Tensor> {
HashMap::from_iter([("weight".to_string(), self.w.clone())])
}
}

struct MLlamaTextMlp {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
@@ -193,8 +223,8 @@ impl MLlamaTextSelfAttention {
struct MLlamaSelfAttentionDecoderLayer {
attn: MLlamaTextSelfAttention,
mlp: MLlamaTextMlp,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
input_layernorm: MLlamaRmsNorm,
post_attention_layernorm: MLlamaRmsNorm,
}

impl MLlamaSelfAttentionDecoderLayer {
@@ -207,12 +237,12 @@ impl MLlamaSelfAttentionDecoderLayer {
loading_isq: bool,
) -> Result<Self> {
let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
let input_layernorm = RmsNorm::new(
let input_layernorm = MLlamaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = RmsNorm::new(
let post_attention_layernorm = MLlamaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
@@ -265,8 +295,8 @@ struct MLlamaTextCrossAttention {
k_proj: Arc<dyn QuantMethod>,
v_proj: Arc<dyn QuantMethod>,
o_proj: Arc<dyn QuantMethod>,
q_norm: RmsNorm,
k_norm: RmsNorm,
q_norm: MLlamaRmsNorm,
k_norm: MLlamaRmsNorm,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
@@ -304,12 +334,12 @@ impl MLlamaTextCrossAttention {
&cfg.quantization_config,
vb.pp("o_proj"),
)?,
q_norm: RmsNorm::new(
q_norm: MLlamaRmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("q_norm"), false),
)?,
k_norm: RmsNorm::new(
k_norm: MLlamaRmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("k_norm"), false),
@@ -378,17 +408,36 @@ impl MLlamaTextCrossAttention {
};

let mut attn_output = {
let att = MatMul.matmul_affine_div(
&q.contiguous()?,
&k.t()?.contiguous()?,
(self.head_dim as f64).sqrt(),
)?;

let att = match attention_mask {
Some(m) => att.broadcast_add(m)?,
None => att,
Some(m) => {
let mut out = m.to_dtype(DType::F32)?.repeat((1, self.num_heads, 1, 1))?;
q.contiguous()?
.to_dtype(DType::F32)?
.matmul_with_alpha_beta(
&k.t()?.contiguous()?.to_dtype(DType::F32)?,
&mut out,
Some(1. / (self.head_dim as f64).sqrt()),
)?;
out.to_dtype(q.dtype())?
}
None => MatMul.matmul_affine_div(
&q.contiguous()?,
&k.t()?.contiguous()?,
(self.head_dim as f64).sqrt(),
)?,
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// let att = MatMul.matmul_affine_div(
// &q.contiguous()?,
// &k.t()?.contiguous()?,
// (self.head_dim as f64).sqrt(),
// )?;

// let att = match attention_mask {
// Some(m) => att.broadcast_add(m)?,
// None => att,
// };
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
// Convert to contiguous as matmul doesn't support strided vs for now.
MatMul
.matmul(&att, &v.contiguous()?)?
@@ -412,8 +461,8 @@ struct MLlamaCrossAttentionDecoderLayer {
attn_gate: Tensor,
mlp: MLlamaTextMlp,
mlp_gate: Tensor,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
input_layernorm: MLlamaRmsNorm,
post_attention_layernorm: MLlamaRmsNorm,
}

impl MLlamaCrossAttentionDecoderLayer {
@@ -425,12 +474,12 @@ impl MLlamaCrossAttentionDecoderLayer {
loading_isq: bool,
) -> Result<Self> {
let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
let input_layernorm = RmsNorm::new(
let input_layernorm = MLlamaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = RmsNorm::new(
let post_attention_layernorm = MLlamaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
@@ -495,10 +544,10 @@ enum MLlamaDecoderLayer {
pub(super) struct MLlamaTextModel {
embed_tokens: Embedding,
lm_head: Arc<dyn QuantMethod>,
norm: RmsNorm,
norm: MLlamaRmsNorm,
layers: Vec<MLlamaDecoderLayer>,
pub(crate) cfg: ModelConfigMetadata,
pub(crate) self_attn_cache: Cache,
pub(crate) cache: Cache,
pub(crate) device: Device,
pub(crate) max_position_embeddings: usize,
mapper: Box<dyn DeviceMapper + Send + Sync>,
@@ -541,7 +590,7 @@ impl MLlamaTextModel {

let vb = vb.pp("model");

let norm = RmsNorm::new(
let norm = MLlamaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_nm_device(vb.pp("norm"), false),
@@ -608,7 +657,7 @@ impl MLlamaTextModel {
sliding_window: None,
head_dim: None,
},
self_attn_cache: Cache::new(cfg.num_hidden_layers, false),
cache: Cache::new(cfg.num_hidden_layers, false),
device: normal_loading_metadata.real_device,
max_position_embeddings: cfg.max_position_embeddings,
mapper,
Expand All @@ -628,7 +677,7 @@ impl MLlamaTextModel {
) -> Result<Tensor> {
let mut hidden_states = self.embed_tokens.forward(input_ids)?;

let mut self_cache = self.self_attn_cache.lock();
let mut cache = self.cache.lock();
let self_mask = CausalMasker.make_causal_mask_as_attn_bias(
input_ids,
&seqlen_offsets as &dyn PastKvLenCache,
Expand All @@ -645,7 +694,7 @@ impl MLlamaTextModel {
self_mask.as_ref(),
seqlen_offsets,
start_offsets_kernel.clone(),
&mut self_cache[i],
&mut cache[i],
)?;
}
MLlamaDecoderLayer::CrossAttn(attn) => {
@@ -660,7 +709,7 @@ impl MLlamaTextModel {
cross_attn_states,
cross_attention_mask,
full_text_row_masked_out_mask,
&mut self_cache[i],
&mut cache[i],
)?;
}
}
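Note: the MLlamaRmsNorm introduced in this file is the standard RMSNorm, with the reduction carried out in f32 and the result cast back to the activation dtype. A minimal standalone sketch of the same pattern with candle is below; the free function rms_norm_f32 is illustrative and not part of the crate.

use candle_core::{DType, Result, Tensor, D};

/// Sketch of the f32 RMSNorm pattern used in text.rs above: upcast to f32,
/// normalize by the root-mean-square over the last dimension, then cast back.
/// `w` is the learned per-channel weight, `eps` the usual stabilizer.
fn rms_norm_f32(xs: &Tensor, w: &Tensor, eps: f64) -> Result<Tensor> {
    let initial_type = xs.dtype();
    let xs = xs.to_dtype(DType::F32)?;
    // Mean of squares over the hidden dimension, kept for broadcasting.
    let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
    // Scale by 1 / sqrt(var + eps), then restore dtype and apply the weight.
    let xs = xs.broadcast_mul(&(&var + eps)?.recip()?.sqrt()?)?;
    xs.to_dtype(initial_type)?.broadcast_mul(w)
}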
13 changes: 8 additions & 5 deletions mistralrs-core/src/vision_models/mllama/vision.rs
@@ -195,16 +195,19 @@ impl MLlamaVisionAttention {

let attn_output = Sdpa
.run_attention(
&q.contiguous()?,
&k.contiguous()?,
&v.contiguous()?,
attention_mask,
&q.contiguous()?.to_dtype(DType::F32)?,
&k.contiguous()?.to_dtype(DType::F32)?,
&v.contiguous()?.to_dtype(DType::F32)?,
attention_mask
.map(|m| m.to_dtype(DType::F32).unwrap())
.as_ref(),
None,
&self.sdpa_params,
)?
.transpose(1, 2)?
.contiguous()?
.reshape((bs, q_sq, ()))?;
.reshape((bs, q_sq, ()))?
.to_dtype(q.dtype())?;

self.o_proj.forward(&attn_output)
}
4 changes: 4 additions & 0 deletions mistralrs/Cargo.toml
@@ -132,3 +132,7 @@ required-features = []
[[example]]
name = "llama_vision"
required-features = []

[[example]]
name = "llama_vision_multiturn"
required-features = []
93 changes: 93 additions & 0 deletions mistralrs/examples/llama_vision_multiturn/main.rs
@@ -0,0 +1,93 @@
use anyhow::Result;
use mistralrs::{
RequestBuilder, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder,
};

const MODEL_ID: &str = "meta-llama/Llama-3.2-11B-Vision-Instruct";

#[tokio::main]
async fn main() -> Result<()> {
let model = VisionModelBuilder::new(MODEL_ID, VisionLoaderType::VLlama)
.with_logging()
.with_isq(mistralrs::IsqType::Q8_0)
.build()
.await?;

let mut messages = VisionMessages::new().add_message(TextMessageRole::User, "Hello!");

let resp = model
.send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
.await?
.choices[0]
.message
.content
.clone()
.unwrap();
println!("\n\n{resp}");
messages = messages.add_message(TextMessageRole::Assistant, resp);

let bytes = match reqwest::blocking::get(
// "https://s3.amazonaws.com/cdn.tulips.com/images/large/Timeless-Tulip.jpg",
"https://niche-museums.imgix.net/pioneer-history.jpeg",
) {
Ok(http_resp) => http_resp.bytes()?.to_vec(),
Err(e) => anyhow::bail!(e),
};
let image = image::load_from_memory(&bytes)?;

messages = messages.add_vllama_image_message(TextMessageRole::User, "What is this?", image);
let resp = model
.send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
.await?
.choices[0]
.message
.content
.clone()
.unwrap();
println!("\n\n{resp}");
messages = messages.add_message(TextMessageRole::Assistant, resp);

let bytes = match reqwest::blocking::get(
"https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
) {
Ok(http_resp) => http_resp.bytes()?.to_vec(),
Err(e) => anyhow::bail!(e),
};
let image = image::load_from_memory(&bytes)?;

messages = messages.add_vllama_image_message(TextMessageRole::User, "What is this?", image);
let resp = model
.send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
.await?
.choices[0]
.message
.content
.clone()
.unwrap();
println!("\n\n{resp}");
messages = messages.add_message(TextMessageRole::Assistant, resp);

let bytes =
match reqwest::blocking::get("https://cdn.britannica.com/79/4679-050-BC127236/Titanic.jpg")
{
Ok(http_resp) => http_resp.bytes()?.to_vec(),
Err(e) => anyhow::bail!(e),
};
let image = image::load_from_memory(&bytes)?;

messages = messages.add_vllama_image_message(TextMessageRole::User, "What is this?", image);
let resp = model
.send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
.await?
.choices[0]
.message
.content
.clone()
.unwrap();
println!("\n\nModel response*: {resp}");
messages = messages.add_message(TextMessageRole::Assistant, resp);

println!("Final chat history: {messages:?}");

Ok(())
}
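To try the new example after cloning the repository, an invocation along these lines should work from the workspace root (the exact command is an assumption, not part of the commit; the Meta Llama vision model it downloads may also require accepting the license and authenticating with Hugging Face):

cargo run --release -p mistralrs --example llama_vision_multiturn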
