From f7cc1893bc4c21d7d768d21c868a9ec2243505c8 Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Wed, 23 Oct 2024 20:29:33 -0400
Subject: [PATCH] Add a faster and nicer fix

---
 .../src/vision_models/mllama/mod.rs           |   2 +-
 .../src/vision_models/mllama/text.rs          | 111 +++++++++++++-----
 .../src/vision_models/mllama/vision.rs        |  13 +-
 mistralrs/Cargo.toml                          |   4 +
 .../examples/llama_vision_multiturn/main.rs   |  93 +++++++++++++++
 5 files changed, 186 insertions(+), 37 deletions(-)
 create mode 100644 mistralrs/examples/llama_vision_multiturn/main.rs

diff --git a/mistralrs-core/src/vision_models/mllama/mod.rs b/mistralrs-core/src/vision_models/mllama/mod.rs
index 2a934a4b19..d07e0e4c38 100644
--- a/mistralrs-core/src/vision_models/mllama/mod.rs
+++ b/mistralrs-core/src/vision_models/mllama/mod.rs
@@ -196,7 +196,7 @@ pub(crate) struct MLlamaSpecificArgs {
 
 impl VisionModel for MLlamaModel {
     fn cache(&self) -> &Cache {
-        &self.language_model.self_attn_cache
+        &self.language_model.cache
     }
     fn config(&self) -> &ModelConfigMetadata {
         &self.language_model.cfg
diff --git a/mistralrs-core/src/vision_models/mllama/text.rs b/mistralrs-core/src/vision_models/mllama/text.rs
index 2e06a1ae1f..beda900ed4 100644
--- a/mistralrs-core/src/vision_models/mllama/text.rs
+++ b/mistralrs-core/src/vision_models/mllama/text.rs
@@ -2,22 +2,52 @@
 
 use std::{collections::HashMap, sync::Arc};
 
-use candle_core::{Device, IndexOp, Result, Tensor};
+use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Activation, Embedding, Module, VarBuilder};
 use mistralrs_quant::{linear_no_bias, QuantMethod, QuantMethodConfig, UnquantLinear};
 
 use crate::{
     attention::SdpaParams,
     device_map::DeviceMapper,
-    layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, RmsNorm, Sdpa},
+    layers::{repeat_kv, CausalMasker, Llama3RotaryEmbedding, MatMul, Sdpa},
     layers_masker::PastKvLenCache,
     paged_attention::{AttentionImplementation, ModelConfigMetadata},
     pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata},
-    utils::unvarbuilder::UnVarBuilder,
+    utils::unvarbuilder::{ToTensors, UnVarBuilder},
 };
 
 use super::config::MLlamaTextConfig;
 
+struct MLlamaRmsNorm {
+    w: Tensor,
+    eps: f64,
+}
+
+impl MLlamaRmsNorm {
+    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        Ok(Self {
+            w: vb.get((size,), "weight")?,
+            eps,
+        })
+    }
+}
+
+impl Module for MLlamaRmsNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let initial_type = xs.dtype();
+        let mut xs = xs.to_dtype(DType::F32)?;
+        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
+        xs = xs.broadcast_mul(&(&var + self.eps)?.recip()?.sqrt()?)?;
+        xs.to_dtype(initial_type)?.broadcast_mul(&self.w)
+    }
+}
+
+impl ToTensors for MLlamaRmsNorm {
+    fn to_tensors(&self) -> HashMap<String, Tensor> {
+        HashMap::from_iter([("weight".to_string(), self.w.clone())])
+    }
+}
+
 struct MLlamaTextMlp {
     gate_proj: Arc<dyn QuantMethod>,
     up_proj: Arc<dyn QuantMethod>,
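
Note on the new norm: MLlamaRmsNorm above forces the whole normalization to run in F32, casting back to the activation dtype only at the end, presumably to track the reference model's float32 RMSNorm more exactly than the shared layers::RmsNorm. A minimal standalone sketch of the same computation (the helper name and the F16 test values are illustrative, not part of the patch):

    use candle_core::{DType, Device, Result, Tensor, D};

    fn rms_norm_f32(xs: &Tensor, w: &Tensor, eps: f64) -> Result<Tensor> {
        let initial_type = xs.dtype();
        // Upcast so the mean-of-squares is not accumulated in half precision.
        let xs = xs.to_dtype(DType::F32)?;
        let var = xs.powf(2.)?.mean_keepdim(D::Minus1)?;
        let xs = xs.broadcast_mul(&(&var + eps)?.recip()?.sqrt()?)?;
        // Cast back to the caller's dtype before applying the learned weight.
        xs.to_dtype(initial_type)?.broadcast_mul(w)
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let x = Tensor::randn(0f32, 1f32, (2, 8), &dev)?.to_dtype(DType::F16)?;
        let w = Tensor::ones(8, DType::F16, &dev)?;
        // The result comes back in F16; only the statistics ran in F32.
        println!("{:?}", rms_norm_f32(&x, &w, 1e-5)?.dtype());
        Ok(())
    }
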
vb.pp("input_layernorm"), false), )?; - let post_attention_layernorm = RmsNorm::new( + let post_attention_layernorm = MLlamaRmsNorm::new( cfg.hidden_size, cfg.rms_norm_eps, mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false), @@ -265,8 +295,8 @@ struct MLlamaTextCrossAttention { k_proj: Arc, v_proj: Arc, o_proj: Arc, - q_norm: RmsNorm, - k_norm: RmsNorm, + q_norm: MLlamaRmsNorm, + k_norm: MLlamaRmsNorm, num_heads: usize, num_kv_heads: usize, head_dim: usize, @@ -304,12 +334,12 @@ impl MLlamaTextCrossAttention { &cfg.quantization_config, vb.pp("o_proj"), )?, - q_norm: RmsNorm::new( + q_norm: MLlamaRmsNorm::new( cfg.head_dim(), cfg.rms_norm_eps, mapper.set_device(layer_idx, vb.pp("q_norm"), false), )?, - k_norm: RmsNorm::new( + k_norm: MLlamaRmsNorm::new( cfg.head_dim(), cfg.rms_norm_eps, mapper.set_device(layer_idx, vb.pp("k_norm"), false), @@ -378,17 +408,36 @@ impl MLlamaTextCrossAttention { }; let mut attn_output = { - let att = MatMul.matmul_affine_div( - &q.contiguous()?, - &k.t()?.contiguous()?, - (self.head_dim as f64).sqrt(), - )?; - let att = match attention_mask { - Some(m) => att.broadcast_add(m)?, - None => att, + Some(m) => { + let mut out = m.to_dtype(DType::F32)?.repeat((1, self.num_heads, 1, 1))?; + q.contiguous()? + .to_dtype(DType::F32)? + .matmul_with_alpha_beta( + &k.t()?.contiguous()?.to_dtype(DType::F32)?, + &mut out, + Some(1. / (self.head_dim as f64).sqrt()), + )?; + out.to_dtype(q.dtype())? + } + None => MatMul.matmul_affine_div( + &q.contiguous()?, + &k.t()?.contiguous()?, + (self.head_dim as f64).sqrt(), + )?, }; - let att = candle_nn::ops::softmax_last_dim(&att)?; + // let att = MatMul.matmul_affine_div( + // &q.contiguous()?, + // &k.t()?.contiguous()?, + // (self.head_dim as f64).sqrt(), + // )?; + + // let att = match attention_mask { + // Some(m) => att.broadcast_add(m)?, + // None => att, + // }; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; // Convert to contiguous as matmul doesn't support strided vs for now. MatMul .matmul(&att, &v.contiguous()?)? 
@@ -412,8 +461,8 @@ struct MLlamaCrossAttentionDecoderLayer {
     attn_gate: Tensor,
     mlp: MLlamaTextMlp,
     mlp_gate: Tensor,
-    input_layernorm: RmsNorm,
-    post_attention_layernorm: RmsNorm,
+    input_layernorm: MLlamaRmsNorm,
+    post_attention_layernorm: MLlamaRmsNorm,
 }
 
 impl MLlamaCrossAttentionDecoderLayer {
@@ -425,12 +474,12 @@ impl MLlamaCrossAttentionDecoderLayer {
         loading_isq: bool,
     ) -> Result<Self> {
         let mlp = MLlamaTextMlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
-        let input_layernorm = RmsNorm::new(
+        let input_layernorm = MLlamaRmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
             mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
         )?;
-        let post_attention_layernorm = RmsNorm::new(
+        let post_attention_layernorm = MLlamaRmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
             mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
@@ -495,10 +544,10 @@ enum MLlamaDecoderLayer {
 pub(super) struct MLlamaTextModel {
     embed_tokens: Embedding,
     lm_head: Arc<dyn QuantMethod>,
-    norm: RmsNorm,
+    norm: MLlamaRmsNorm,
     layers: Vec<MLlamaDecoderLayer>,
     pub(crate) cfg: ModelConfigMetadata,
-    pub(crate) self_attn_cache: Cache,
+    pub(crate) cache: Cache,
     pub(crate) device: Device,
     pub(crate) max_position_embeddings: usize,
     mapper: Box<dyn DeviceMapper + Send + Sync>,
@@ -541,7 +590,7 @@ impl MLlamaTextModel {
 
         let vb = vb.pp("model");
 
-        let norm = RmsNorm::new(
+        let norm = MLlamaRmsNorm::new(
             cfg.hidden_size,
             cfg.rms_norm_eps,
             mapper.set_nm_device(vb.pp("norm"), false),
@@ -608,7 +657,7 @@ impl MLlamaTextModel {
                 sliding_window: None,
                 head_dim: None,
             },
-            self_attn_cache: Cache::new(cfg.num_hidden_layers, false),
+            cache: Cache::new(cfg.num_hidden_layers, false),
             device: normal_loading_metadata.real_device,
             max_position_embeddings: cfg.max_position_embeddings,
             mapper,
@@ -628,7 +677,7 @@ impl MLlamaTextModel {
     ) -> Result<Tensor> {
         let mut hidden_states = self.embed_tokens.forward(input_ids)?;
 
-        let mut self_cache = self.self_attn_cache.lock();
+        let mut cache = self.cache.lock();
         let self_mask = CausalMasker.make_causal_mask_as_attn_bias(
             input_ids,
             &seqlen_offsets as &dyn PastKvLenCache,
@@ -645,7 +694,7 @@ impl MLlamaTextModel {
                         self_mask.as_ref(),
                         seqlen_offsets,
                         start_offsets_kernel.clone(),
-                        &mut self_cache[i],
+                        &mut cache[i],
                     )?;
                 }
                 MLlamaDecoderLayer::CrossAttn(attn) => {
@@ -660,7 +709,7 @@ impl MLlamaTextModel {
                         cross_attn_states,
                         cross_attention_mask,
                         full_text_row_masked_out_mask,
-                        &mut self_cache[i],
+                        &mut cache[i],
                     )?;
                 }
             }
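
On the self_attn_cache -> cache rename: the single Cache vector is indexed by absolute layer position and now serves both layer kinds, so the old name wrongly suggested it was exclusive to self-attention. One way to picture what the cross-attention entries buy in the multi-turn example below: K/V projected from the vision encoder states can be cached on the turn that carries an image and reused on later turns. A conceptual sketch only (this cross_kv helper is hypothetical and not part of the patch; the real logic lives in MLlamaTextCrossAttention::forward):

    use candle_core::{Result, Tensor};

    fn cross_kv(
        cache_entry: &mut Option<(Tensor, Tensor)>,
        vision_states: Option<&Tensor>,
        project: impl Fn(&Tensor) -> Result<(Tensor, Tensor)>,
    ) -> Result<Option<(Tensor, Tensor)>> {
        if let Some(xs) = vision_states {
            // A new image this turn: project its K/V and cache them.
            *cache_entry = Some(project(xs)?);
        }
        // Otherwise reuse whatever was cached earlier; None means no image
        // has been seen yet and the gated cross-attention contributes nothing.
        Ok(cache_entry.clone())
    }
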
diff --git a/mistralrs-core/src/vision_models/mllama/vision.rs b/mistralrs-core/src/vision_models/mllama/vision.rs
index c24308f2d4..4f487d4c13 100644
--- a/mistralrs-core/src/vision_models/mllama/vision.rs
+++ b/mistralrs-core/src/vision_models/mllama/vision.rs
@@ -195,16 +195,19 @@ impl MLlamaVisionAttention {
 
         let attn_output = Sdpa
             .run_attention(
-                &q.contiguous()?,
-                &k.contiguous()?,
-                &v.contiguous()?,
-                attention_mask,
+                &q.contiguous()?.to_dtype(DType::F32)?,
+                &k.contiguous()?.to_dtype(DType::F32)?,
+                &v.contiguous()?.to_dtype(DType::F32)?,
+                attention_mask
+                    .map(|m| m.to_dtype(DType::F32).unwrap())
+                    .as_ref(),
                 None,
                 &self.sdpa_params,
             )?
             .transpose(1, 2)?
             .contiguous()?
-            .reshape((bs, q_sq, ()))?;
+            .reshape((bs, q_sq, ()))?
+            .to_dtype(q.dtype())?;
 
         self.o_proj.forward(&attn_output)
     }
diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml
index 2aec4cf6bb..e6a8be8a4d 100644
--- a/mistralrs/Cargo.toml
+++ b/mistralrs/Cargo.toml
@@ -132,3 +132,7 @@ required-features = []
 [[example]]
 name = "llama_vision"
 required-features = []
+
+[[example]]
+name = "llama_vision_multiturn"
+required-features = []
diff --git a/mistralrs/examples/llama_vision_multiturn/main.rs b/mistralrs/examples/llama_vision_multiturn/main.rs
new file mode 100644
index 0000000000..3f06e9072d
--- /dev/null
+++ b/mistralrs/examples/llama_vision_multiturn/main.rs
@@ -0,0 +1,93 @@
+use anyhow::Result;
+use mistralrs::{
+    RequestBuilder, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder,
+};
+
+const MODEL_ID: &str = "meta-llama/Llama-3.2-11B-Vision-Instruct";
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let model = VisionModelBuilder::new(MODEL_ID, VisionLoaderType::VLlama)
+        .with_logging()
+        .with_isq(mistralrs::IsqType::Q8_0)
+        .build()
+        .await?;
+
+    let mut messages = VisionMessages::new().add_message(TextMessageRole::User, "Hello!");
+
+    let resp = model
+        .send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
+        .await?
+        .choices[0]
+        .message
+        .content
+        .clone()
+        .unwrap();
+    println!("\n\n{resp}");
+    messages = messages.add_message(TextMessageRole::Assistant, resp);
+
+    let bytes = match reqwest::blocking::get(
+        // "https://s3.amazonaws.com/cdn.tulips.com/images/large/Timeless-Tulip.jpg",
+        "https://niche-museums.imgix.net/pioneer-history.jpeg",
+    ) {
+        Ok(http_resp) => http_resp.bytes()?.to_vec(),
+        Err(e) => anyhow::bail!(e),
+    };
+    let image = image::load_from_memory(&bytes)?;
+
+    messages = messages.add_vllama_image_message(TextMessageRole::User, "What is this?", image);
+    let resp = model
+        .send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
+        .await?
+        .choices[0]
+        .message
+        .content
+        .clone()
+        .unwrap();
+    println!("\n\n{resp}");
+    messages = messages.add_message(TextMessageRole::Assistant, resp);
+
+    let bytes = match reqwest::blocking::get(
+        "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
+    ) {
+        Ok(http_resp) => http_resp.bytes()?.to_vec(),
+        Err(e) => anyhow::bail!(e),
+    };
+    let image = image::load_from_memory(&bytes)?;
+
+    messages = messages.add_vllama_image_message(TextMessageRole::User, "What is this?", image);
+    let resp = model
+        .send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
+        .await?
+        .choices[0]
+        .message
+        .content
+        .clone()
+        .unwrap();
+    println!("\n\n{resp}");
+    messages = messages.add_message(TextMessageRole::Assistant, resp);
+
+    let bytes =
+        match reqwest::blocking::get("https://cdn.britannica.com/79/4679-050-BC127236/Titanic.jpg")
+        {
+            Ok(http_resp) => http_resp.bytes()?.to_vec(),
+            Err(e) => anyhow::bail!(e),
+        };
+    let image = image::load_from_memory(&bytes)?;
+
+    messages = messages.add_vllama_image_message(TextMessageRole::User, "What is this?", image);
+    let resp = model
+        .send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
+        .await?
+        .choices[0]
+        .message
+        .content
+        .clone()
+        .unwrap();
+    println!("\n\n{resp}");
+    messages = messages.add_message(TextMessageRole::Assistant, resp);
+
+    println!("Final chat history: {messages:?}");
+
+    Ok(())
+}
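
With the [[example]] entry above, the example should run from the workspace root with something like cargo run --release --example llama_vision_multiturn (no extra features are declared). The send/print/append block is repeated verbatim four times in main.rs; a small helper in the same style would tighten it. A sketch against the same builder API (the Model type name is assumed from what VisionModelBuilder::build returns; everything else is copied from the example):

    use anyhow::Result;
    use mistralrs::{Model, RequestBuilder, TextMessageRole, VisionMessages};

    // Send the running conversation, print the reply, and append it as the
    // assistant turn, exactly as each block in main.rs does by hand.
    async fn send_and_log(model: &Model, messages: &mut VisionMessages) -> Result<String> {
        let resp = model
            .send_chat_request(RequestBuilder::from(messages.clone()).set_sampler_max_len(100))
            .await?
            .choices[0]
            .message
            .content
            .clone()
            .unwrap();
        println!("\n\n{resp}");
        *messages = messages.clone().add_message(TextMessageRole::Assistant, resp.clone());
        Ok(resp)
    }
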