From a8c2b41a859bc937fd38a1e3ae035068f3293e3b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 4 Jun 2024 20:12:22 -0400 Subject: [PATCH 1/9] Support multiple GGUF files (#379) * Move to gguf module * Add content abstraction for multiple gguf files * Fix test * Allow specifying and loading multiple gguf files * Update docs and examples * Print some info --- Cargo.toml | 2 +- README.md | 8 + mistralrs-core/src/device_map.rs | 1 + mistralrs-core/src/gguf/content.rs | 167 ++++++++++++++++++ .../src/{pipeline => gguf}/gguf_tokenizer.rs | 42 +++-- mistralrs-core/src/gguf/mod.rs | 5 + mistralrs-core/src/layers.rs | 11 +- mistralrs-core/src/lib.rs | 12 +- mistralrs-core/src/model_loader.rs | 16 +- mistralrs-core/src/models/quantized_llama.rs | 84 +++++---- mistralrs-core/src/models/quantized_phi2.rs | 52 +++--- mistralrs-core/src/models/quantized_phi3.rs | 74 ++++---- mistralrs-core/src/pipeline/ggml.rs | 2 +- mistralrs-core/src/pipeline/gguf.rs | 95 +++------- mistralrs-core/src/pipeline/mod.rs | 3 +- mistralrs-core/src/pipeline/paths.rs | 16 +- mistralrs-core/src/toml_selector.rs | 17 +- mistralrs-core/src/utils/max_seq_len.rs | 20 --- mistralrs-core/src/utils/mod.rs | 1 - mistralrs-core/src/utils/model_config.rs | 40 ++--- .../src/xlora_models/quantized_llama.rs | 84 +++++---- .../src/xlora_models/quantized_phi3.rs | 52 +++--- mistralrs-pyo3/API.md | 8 +- mistralrs-pyo3/README.md | 28 +-- mistralrs-pyo3/mistralrs.pyi | 6 +- mistralrs-pyo3/src/lib.rs | 12 +- mistralrs-pyo3/src/which.rs | 7 +- mistralrs/examples/gguf_locally/main.rs | 2 +- mistralrs/examples/quantized/main.rs | 2 +- 29 files changed, 498 insertions(+), 371 deletions(-) create mode 100644 mistralrs-core/src/gguf/content.rs rename mistralrs-core/src/{pipeline => gguf}/gguf_tokenizer.rs (90%) create mode 100644 mistralrs-core/src/gguf/mod.rs delete mode 100644 mistralrs-core/src/utils/max_seq_len.rs diff --git a/Cargo.toml b/Cargo.toml index 8ecadfccf..280bdbf47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } futures = "0.3" clap = { version = "4.5.1", features = ["derive"] } -pyo3 = { version = "0.21.0", features = ["full", "extension-module"] } +pyo3 = { version = "0.21.0", features = ["full", "extension-module", "either"] } tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] } once_cell = "1.19.0" image = "0.25.1" diff --git a/README.md b/README.md index d63b0526a..b987238ba 100644 --- a/README.md +++ b/README.md @@ -283,6 +283,14 @@ please consider using the method demonstrated in examples below, where the token **Supported GGUF tokenizer types** - `llama` +Some GGUF models are very large and are sharded into multiple files. Mistral.rs supports this, and to use it, delimit the `.gguf` filenames with a space as such: + +```bash +./mistralrs-server --chat-template gguf -m . -f "a.gguf b.gguf" +``` + +For the Python API, a list of strings is also accepted for this case. + ## Run To start a server serving Mistral GGUF on `localhost:1234`, diff --git a/mistralrs-core/src/device_map.rs b/mistralrs-core/src/device_map.rs index f0aacfdf6..23062ca7e 100644 --- a/mistralrs-core/src/device_map.rs +++ b/mistralrs-core/src/device_map.rs @@ -18,6 +18,7 @@ impl DeviceMapMetadata { host_layers: None, } } + /// A device mapper to not map device. 
pub fn dummy() -> Self { Self { device_layers: None, diff --git a/mistralrs-core/src/gguf/content.rs b/mistralrs-core/src/gguf/content.rs new file mode 100644 index 000000000..52e972434 --- /dev/null +++ b/mistralrs-core/src/gguf/content.rs @@ -0,0 +1,167 @@ +use std::fs; + +use anyhow::Context; +use candle_core::{ + quantized::{ + gguf_file::{self, Value}, + QTensor, + }, + Device, Result, +}; +use indexmap::IndexMap; +use tracing::info; + +use crate::{pipeline::GGUFArchitecture, DEBUG}; + +fn parse_gguf_value(value: &Value) -> String { + match value { + Value::Array(vs) => vs + .iter() + .map(parse_gguf_value) + .collect::>() + .join(", "), + Value::Bool(b) => b.to_string(), + Value::F32(x) => x.to_string(), + Value::F64(x) => x.to_string(), + Value::I8(x) => x.to_string(), + Value::I16(x) => x.to_string(), + Value::I32(x) => x.to_string(), + Value::I64(x) => x.to_string(), + Value::String(x) => x.to_string(), + Value::U8(x) => x.to_string(), + Value::U16(x) => x.to_string(), + Value::U32(x) => x.to_string(), + Value::U64(x) => x.to_string(), + } +} + +// Internal invariant: contents and readers must be paired. +/// This abstracts the files for a GGUF model and enables multiple files to be used. +pub struct Content<'a, R: std::io::Seek + std::io::Read> { + contents: Vec, + readers: &'a mut [&'a mut R], + arch: GGUFArchitecture, +} + +impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> { + /// Create a `Content` from a set of file readers. + pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result { + let mut contents = Vec::new(); + let n_readers = readers.len(); + for reader in readers.iter_mut() { + contents.push(gguf_file::Content::read(reader)?); + } + let n_splits = contents + .iter() + .filter_map(|ct| { + ct.metadata + .get("split.count") + .map(|val| val.to_u64().unwrap()) + }) + .collect::>(); + if n_splits.len() > 1 { + candle_core::bail!("Multiple contents have multiple `split.count` fields"); + } + #[allow(clippy::cast_possible_truncation)] + if !n_splits.is_empty() && n_readers != n_splits[0] as usize { + candle_core::bail!("Number of readers does not match the number of splits."); + } else if n_splits.len() == 1 { + info!("Model n splits: {}", n_splits[0]); + } + + let mut arch = None; + for ct in &contents { + if !ct.metadata.contains_key("general.architecture") { + continue; + } + + arch = Some( + ct.metadata["general.architecture"] + .to_string() + .context("Model metadata should have declared an architecture") + .and_then(GGUFArchitecture::from_value) + .unwrap(), + ); + } + let arch = arch.expect("GGUF files must specify `general.architecture`"); + Ok(Self { + contents, + readers, + arch, + }) + } + + pub fn arch(&self) -> GGUFArchitecture { + self.arch + } + + /// Retrieve a tensor, searching through each content. + pub fn tensor(&mut self, name: &str, device: &Device) -> Result { + for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) { + if let Some(tensor_info) = ct.tensor_infos.get(name) { + return tensor_info.read(reader, ct.tensor_data_offset, device); + } + } + candle_core::bail!("Cannot find tensor info for {name}") + } + + /// Print metadata for these contents. + /// This will also log tensor name, shape and dtype to `mistralrs_gguf_tensors.txt` is DEBUG is enabled. 
+ pub fn print_metadata(&self) -> anyhow::Result<()> { + // Find the ct with general.architecture + let mut keys = Vec::new(); + let mut metadatas = Vec::new(); + let mut tensors = Vec::new(); + for ct in &self.contents { + keys.extend(ct.metadata.keys()); + metadatas.push(&ct.metadata); + + if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { + for (name, info) in &ct.tensor_infos { + tensors.push(format!( + "name = `{name}`, shape = {:?}, dtype = {:?}", + info.shape.clone(), + info.ggml_dtype + )); + } + } + } + + info!("Model config:"); + keys.sort(); + let mut output_keys = IndexMap::new(); + for name in keys { + if !name.contains("tokenizer") { + for metadata in &metadatas { + if let Some(val) = metadata.get(name) { + output_keys.insert(name, parse_gguf_value(val)); + } + } + } + } + for (name, val) in output_keys { + println!("{name}: {val}") + } + + if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { + fs::write( + "mistralrs_gguf_tensors.txt", + serde_json::to_string_pretty(&tensors).expect("Serialization failed."), + )?; + + info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); + } + + anyhow::Ok(()) + } + + /// Get metadata + pub fn get_metadata(&self, name: &str) -> Result<&Value> { + for content in &self.contents { + if let Some(v) = content.metadata.get(name) { + return Ok(v); + } + } + candle_core::bail!("Cannot find metadata for {name}") + } +} diff --git a/mistralrs-core/src/pipeline/gguf_tokenizer.rs b/mistralrs-core/src/gguf/gguf_tokenizer.rs similarity index 90% rename from mistralrs-core/src/pipeline/gguf_tokenizer.rs rename to mistralrs-core/src/gguf/gguf_tokenizer.rs index 4d0ce613c..363aaffb0 100644 --- a/mistralrs-core/src/pipeline/gguf_tokenizer.rs +++ b/mistralrs-core/src/gguf/gguf_tokenizer.rs @@ -1,7 +1,6 @@ use std::sync::atomic::Ordering; use anyhow::Result; -use candle_core::quantized::gguf_file::Content; use tokenizers::{ decoders::{self, byte_fallback::ByteFallback, fuse::Fuse, strip::Strip}, models::unigram::Unigram, @@ -12,27 +11,33 @@ use tracing::info; use crate::DEBUG; -pub struct ConversionResult { +use super::Content; + +pub struct GgufTokenizerConversion { pub tokenizer: Tokenizer, pub bos: Option, pub eos: Option, pub unk: Option, } -pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { - let model = content.metadata["tokenizer.ggml.model"] +// Convert GGUF tokenizer to tokenizer and metadata +pub fn convert_gguf_to_hf_tokenizer( + content: &Content<'_, R>, +) -> Result { + let model = content + .get_metadata("tokenizer.ggml.model")? .to_string() .expect("GGUF tokenizer model is not a string.") .clone(); - let tokens = content.metadata["tokenizer.ggml.tokens"] + let tokens = content + .get_metadata("tokenizer.ggml.tokens")? 
.to_vec() .expect("GGUF tokenizer tokens is not a vec.") .iter() .map(|t| t.to_string().expect("GGUF token is not a string.").clone()) .collect::>(); let added_tokens = content - .metadata - .get("tokenizer.ggml.added_tokens") + .get_metadata("tokenizer.ggml.added_tokens") .map(|items| { items .to_vec() @@ -45,7 +50,7 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result>() }); - let scores = content.metadata.get("tokenizer.ggml.scores").map(|items| { + let scores = content.get_metadata("tokenizer.ggml.scores").map(|items| { items .to_vec() .expect("GGUF tokenizer scores is not a vec.") @@ -53,7 +58,7 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result>() }); - let merges = content.metadata.get("tokenizer.ggml.merges").map(|items| { + let merges = content.get_metadata("tokenizer.ggml.merges").map(|items| { items .to_vec() .expect("GGUF tokenizer merges is not a vec.") @@ -63,15 +68,16 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result Result Result( - ct: &gguf_file::Content, - r: &mut R, + ct: &mut Content<'_, R>, name: &str, device: &Device, ) -> Result { - let w = ct.tensor(r, &format!("{name}.weight"), device)?; - let b = ct.tensor(r, &format!("{name}.bias"), device)?; + let w = ct.tensor(&format!("{name}.weight"), device)?; + let b = ct.tensor(&format!("{name}.bias"), device)?; let inner = QMatMul::from_qtensor(w)?; let bias = b.dequantize(device)?; Ok(Self { diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 8132071ad..da1757ab1 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -26,6 +26,7 @@ mod model_selected; pub use model_selected::ModelSelected; mod cublaslt; +pub mod gguf; pub mod layers; mod layers_masker; mod layers_utils; @@ -42,11 +43,12 @@ mod utils; mod xlora_models; pub use device_map::{DeviceMapMetadata, LayerDeviceMapper}; +pub use gguf::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion}; pub use pipeline::{ - GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, - GGUFSpecificConfig, GemmaLoader, LlamaLoader, Loader, LocalModelPaths, MistralLoader, - MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder, NormalLoaderType, - NormalSpecificConfig, Phi2Loader, Phi3Loader, Qwen2Loader, SpeculativeConfig, + GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, + GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, LlamaLoader, Loader, LocalModelPaths, + MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder, + NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, TokenSource, }; pub use request::{Constraint, Content, NormalRequest, Request, RequestMessage}; @@ -60,6 +62,8 @@ pub use toml_selector::{TomlLoaderArgs, TomlSelector}; /// `true` if `MISTRALRS_DEBUG=1` pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false); +/// Delimiter for GGUF multiple files in the CLI. +pub const GGUF_MULTI_FILE_DELIMITER: &str = " "; /// The MistralRs struct handles sending requests to the engine. 
/// It is the core multi-threaded component of mistral.rs, and uses `mspc` diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index 3d61eb62c..c36d19804 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -6,6 +6,7 @@ use crate::{ NormalSpecificConfig, }, Loader, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, + GGUF_MULTI_FILE_DELIMITER, }; pub struct LoaderBuilder { @@ -158,7 +159,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result>(), ) .build(), ModelSelected::XLoraGGUF { @@ -174,7 +178,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result>(), ) .with_xlora( xlora_model_id, @@ -198,7 +205,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result>(), ) .with_lora( adapters_model_id, diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index a5771038c..15bb3a58a 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -1,16 +1,16 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use candle_core::quantized::ggml_file; use candle_core::quantized::QMatMul; -use candle_core::quantized::{ggml_file, gguf_file}; use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, RotaryEmbedding}; use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention, }; use crate::pipeline::{extract_logits, Cache}; -use crate::utils::max_seq_len::get_gguf_max_seq_len; use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; @@ -260,49 +260,50 @@ impl ModelConfig::FromGGML for ModelWeights { impl ModelConfig::FromGGUF for ModelWeights { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, mapper: DeviceMapMetadata, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; verify_sanity_gguf( - md_get("general.architecture")?.to_string().unwrap(), + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), "llama", )?; // Parameter extraction from metadata. - let n_expert = md_get("llama.expert_count") + let n_expert = ct + .get_metadata("llama.expert_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let n_expert_used = md_get("llama.expert_used_count") + let n_expert_used = ct + .get_metadata("llama.expert_used_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + let head_count = ct.get_metadata("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("llama.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("llama.rope.dimension_count")?.to_u32()? 
as usize; // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. - let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + let rms_norm_eps = ct + .get_metadata("llama.attention.layer_norm_rms_epsilon")? + .to_f32()?; - let max_seq_len = - get_gguf_max_seq_len(md_get("llama.context_length"), MAX_SEQ_LEN as u64) as usize; + let max_seq_len = ct + .get_metadata("llama.context_length")? + .to_u64() + .unwrap_or(MAX_SEQ_LEN as u64) as usize; - let rope_freq_base = md_get("llama.rope.freq_base") + let rope_freq_base = ct + .get_metadata("llama.rope.freq_base") .and_then(|m| m.to_f32()) .unwrap_or(10000f32); let head_dim = embedding_length / head_count; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let norm = QRmsNorm::new( - ct.tensor(reader, "output_norm.weight", device)?, - rms_norm_eps, - )?; - let output = ct.tensor(reader, "output.weight", device)?; + let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?; + let output = ct.tensor("output.weight", device)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; for layer_idx in 0..block_count { @@ -318,18 +319,14 @@ impl ModelConfig::FromGGUF for ModelWeights { DType::F32, )?; - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; - let attention_wo = - ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?; let mlp_or_moe = if n_expert <= 1 { - let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; - let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; - let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?; MlpOrMoe::Mlp(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -337,15 +334,15 @@ impl ModelConfig::FromGGUF for ModelWeights { }) } else { let feed_forward_gate_inp = - ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?; let mut experts = Vec::with_capacity(n_expert); for i in 0..n_expert { let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?; let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?; let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), 
device)?; experts.push(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -358,9 +355,8 @@ impl ModelConfig::FromGGUF for ModelWeights { experts, } }; - let attention_norm = - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; + let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?; layers.push(LayerWeights { attention_wq: QMatMul::from_qtensor(attention_wq)?, attention_wk: QMatMul::from_qtensor(attention_wk)?, diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index 7a752f787..83a61a9cc 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -1,15 +1,14 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::quantized::gguf_file; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, LayerNorm}; use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::ScaledDotProductAttention; use crate::layers::{repeat_kv, CausalMasker, QLinear}; use crate::pipeline::{extract_logits, Cache}; -use crate::utils::max_seq_len::get_gguf_max_seq_len; use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; @@ -145,53 +144,52 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result { impl ModelConfig::FromGGUF for ModelWeights { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, mapper: DeviceMapMetadata, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - // Parameter extraction from metadata. - let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("phi2.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize; - let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64; - let max_seq_len = - get_gguf_max_seq_len(md_get("phi2.context_length"), MAX_SEQ_LEN as u64) as usize; + let head_count = ct.get_metadata("phi2.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("phi2.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("phi2.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("phi2.embedding_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("phi2.rope.dimension_count")?.to_u32()? as usize; + let ln_eps = ct + .get_metadata("phi2.attention.layer_norm_epsilon")? + .to_f32()? as f64; + + let max_seq_len = ct + .get_metadata("phi2.context_length")? 
+ .to_u64() + .unwrap_or(MAX_SEQ_LEN as u64) as usize; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; let output_norm = layer_norm( - ct.tensor(reader, "output_norm.weight", device)?, - ct.tensor(reader, "output_norm.bias", device)?, + ct.tensor("output_norm.weight", device)?, + ct.tensor("output_norm.bias", device)?, ln_eps, )?; - let output = QLinear::new(&ct, reader, "output", device)?; + let output = QLinear::new(&mut ct, "output", device)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let device = mapper.device_for(layer_idx, false).unwrap_or(device); - let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; - let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; + let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?; + let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?; let mlp = Mlp { ffn_up, ffn_down }; let attn_norm = layer_norm( - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, - ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?, ln_eps, )?; layers.push(LayerWeights { - attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, - attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, + attn_qkv: QLinear::new(&mut ct, &format!("{prefix}.attn_qkv"), device)?, + attn_output: QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?, attn_norm, mlp, n_head: head_count, diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index 8149eea84..b461e5abd 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -1,13 +1,13 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention, }; use crate::pipeline::Cache; use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; -use candle_core::quantized::gguf_file; use candle_core::quantized::QMatMul; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; @@ -162,71 +162,63 @@ fn precomput_freqs_cis( impl ModelConfig::FromGGUF for ModelWeights { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, mapper: DeviceMapMetadata, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?; + verify_sanity_gguf( + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), + "phi3", + )?; // Parameter extraction from metadata. - let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? 
as usize; - let block_count = md_get("phi3.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize; - let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize; - let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize; - let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; - let context_window = md_get("phi3.context_length")?.to_u32()? as usize; + let head_count = ct.get_metadata("phi3.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("phi3.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("phi3.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("phi3.embedding_length")?.to_u32()? as usize; + let i_size = ct.get_metadata("phi3.feed_forward_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("phi3.rope.dimension_count")?.to_u32()? as usize; + let rms_eps = ct + .get_metadata("phi3.attention.layer_norm_rms_epsilon")? + .to_f32()? as f64; + let context_window = ct.get_metadata("phi3.context_length")?.to_u32()? as usize; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?; - let output = QMatMul::from_qtensor(ct.tensor(reader, "output.weight", device)?)?; + let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?; + let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let device = mapper.device_for(layer_idx, false).unwrap_or(device); - let ffn_up = QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.ffn_up.weight"), - device, - )?)?; - let ffn_down = QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.ffn_down.weight"), - device, - )?)?; + let ffn_up = + QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?)?; + let ffn_down = + QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?)?; let mlp = Mlp { ffn_up, ffn_down, i_size, }; let attn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?, rms_eps, )?; let ffn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?, rms_eps, )?; layers.push(LayerWeights { - attn_qkv: QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.attn_qkv.weight"), - device, - )?)?, - attn_output: QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.attn_output.weight"), - device, - )?)?, + attn_qkv: QMatMul::from_qtensor( + ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?, + )?, + attn_output: QMatMul::from_qtensor( + ct.tensor(&format!("{prefix}.attn_output.weight"), device)?, + )?, attn_norm, ffn_norm, mlp, diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index a652a996a..a9f5d7843 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -369,7 +369,7 @@ impl Loader for GGMLLoader { 
revision, self, self.quantized_model_id, - self.quantized_filename, + Some(vec![self.quantized_filename.as_ref().unwrap().clone()]), silent ); self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index a08fdc09f..60ab71740 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -10,9 +10,9 @@ use super::{ }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; +use crate::gguf::Content; use crate::lora::Ordering; use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig}; -use crate::pipeline::gguf_tokenizer::{convert_ggml_to_hf_tokenizer, ConversionResult}; use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; @@ -20,19 +20,18 @@ use crate::sequence::Sequence; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; -use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline, DEBUG}; use crate::{ + convert_gguf_to_hf_tokenizer, models::quantized_llama::ModelWeights as QLlama, models::quantized_phi2::ModelWeights as QPhi, models::quantized_phi3::ModelWeights as QPhi3, utils::tokens::get_token, xlora_models::{XLoraQLlama, XLoraQPhi3}, + GgufTokenizerConversion, }; +use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline, DEBUG}; use anyhow::{bail, Context, Result}; -use candle_core::quantized::{ - gguf_file::{self, Value as GgufValue}, - GgmlDType, -}; +use candle_core::quantized::GgmlDType; use candle_core::{DType, Device, Tensor}; use either::Either; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; @@ -72,7 +71,7 @@ pub struct GGUFLoader { model_id: Option, config: GGUFSpecificConfig, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, xlora_model_id: Option, xlora_order: Option, no_kv_cache: bool, @@ -81,9 +80,9 @@ pub struct GGUFLoader { tgt_non_granular_index: Option, } -#[derive(Debug, EnumString)] +#[derive(Copy, Clone, Debug, EnumString)] #[strum(serialize_all = "kebab-case")] -enum GGUFArchitecture { +pub enum GGUFArchitecture { Llama, Mpt, Gptneox, @@ -101,7 +100,8 @@ enum GGUFArchitecture { // - Case-insensitive variant matching (TODO: is this desirable?) // - Customized error until potential upstream support: https://github.com/Peternator7/strum/issues/332 impl GGUFArchitecture { - fn from_value + std::fmt::Display>(value: T) -> Result { + /// GGUF architecture from a kebab-case representation. 
+ pub fn from_value + std::fmt::Display>(value: T) -> Result { Self::from_str(&value.as_ref().to_ascii_lowercase()) .with_context(|| format!("Unknown GGUF architecture `{value}`")) .map_err(anyhow::Error::msg) @@ -120,7 +120,7 @@ pub struct GGUFLoaderBuilder { model_id: Option, config: GGUFSpecificConfig, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, xlora_model_id: Option, kind: ModelKind, xlora_order: Option, @@ -138,7 +138,7 @@ impl GGUFLoaderBuilder { chat_template: Option, tok_model_id: Option, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, ) -> Self { let kind = ModelKind::Quantized { quant: QuantizationKind::Gguf, @@ -223,7 +223,7 @@ impl GGUFLoader { model_id: Option, config: GGUFSpecificConfig, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, xlora_model_id: Option, kind: ModelKind, xlora_order: Option, @@ -257,28 +257,6 @@ impl GGUFLoader { } } -fn parse_gguf_value(value: &GgufValue) -> String { - match value { - GgufValue::Array(vs) => vs - .iter() - .map(parse_gguf_value) - .collect::>() - .join(", "), - GgufValue::Bool(b) => b.to_string(), - GgufValue::F32(x) => x.to_string(), - GgufValue::F64(x) => x.to_string(), - GgufValue::I8(x) => x.to_string(), - GgufValue::I16(x) => x.to_string(), - GgufValue::I32(x) => x.to_string(), - GgufValue::I64(x) => x.to_string(), - GgufValue::String(x) => x.to_string(), - GgufValue::U8(x) => x.to_string(), - GgufValue::U16(x) => x.to_string(), - GgufValue::U32(x) => x.to_string(), - GgufValue::U64(x) => x.to_string(), - } -} - impl Loader for GGUFLoader { #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn load_model_from_hf( @@ -337,50 +315,25 @@ impl Loader for GGUFLoader { info!("Loading model `{}` on {device:?}...", self.get_id()); } - let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?; - let model = gguf_file::Content::read(&mut file) - .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?; - let arch = model.metadata["general.architecture"] - .to_string() - .context("Model metadata should have declared an architecture") - .and_then(GGUFArchitecture::from_value)?; - - info!("Model config:"); - let mut sorted_keys = model.metadata.keys().collect::>(); - sorted_keys.sort(); - for name in sorted_keys { - if !name.contains("tokenizer") { - let value = parse_gguf_value(&model.metadata[name]); - println!("{name}: {}", value); - } + let mut files = Vec::new(); + for weight_filename in paths.get_weight_filenames() { + files.push(std::fs::File::open(weight_filename)?); } + let mut files = files.iter_mut().collect::>(); - if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { - let mut tensors = Vec::new(); - for (name, info) in &model.tensor_infos { - tensors.push(format!( - "name = `{name}`, shape = {:?}, dtype = {:?}", - info.shape.clone(), - info.ggml_dtype - )); - } - fs::write( - "mistralrs_gguf_tensors.txt", - serde_json::to_string_pretty(&tensors).expect("Serialization failed."), - )?; - - info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); - } + let content = Content::from_readers(&mut files)?; + let arch = content.arch(); - let ConversionResult { + // Set bos/eos/unk to None to avoid the override + let GgufTokenizerConversion { tokenizer, bos, eos, unk, } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() { - convert_ggml_to_hf_tokenizer(&model)? + convert_gguf_to_hf_tokenizer(&content)? 
} else { - ConversionResult { + GgufTokenizerConversion { tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?, bos: None, eos: None, @@ -393,7 +346,7 @@ impl Loader for GGUFLoader { let model_config = { // Base config (quantization only): - let quant = ModelConfig::ParamsGGUF((model, &mut file).into(), (device, mapper).into()); + let quant = ModelConfig::ParamsGGUF(content, (device, mapper).into()); // With optional adapter config: let mut adapter = None; diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 9c4719fc7..250987be1 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -2,7 +2,6 @@ mod cache_manager; mod chat_template; mod ggml; mod gguf; -mod gguf_tokenizer; mod inputs_processor; mod isq; mod macros; @@ -21,7 +20,7 @@ use candle_core::quantized::GgmlDType; use chat_template::ChatTemplate; use core::fmt; pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig}; -pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig}; +pub use gguf::{GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig}; pub use isq::IsqModel; pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig}; pub use normal_loaders::{ diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index 0aea537d4..312e31ea7 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -2,7 +2,6 @@ use std::{ collections::HashMap, fs, path::{Path, PathBuf}, - str::FromStr, }; use anyhow::Result; @@ -251,14 +250,16 @@ pub fn get_model_paths( revision: String, token_source: &TokenSource, quantized_model_id: &Option, - quantized_filename: &Option, + quantized_filename: &Option>, api: &ApiRepo, model_id: &Path, ) -> Result> { match &quantized_filename { - Some(name) => match quantized_model_id.as_ref().unwrap().as_str() { - "" => Ok(vec![PathBuf::from_str(name).unwrap()]), - id => { + Some(names) => { + let id = quantized_model_id.as_ref().unwrap(); + let mut files = Vec::new(); + + for name in names { let qapi = ApiBuilder::new() .with_progress(true) .with_token(get_token(token_source)?) 
@@ -269,9 +270,10 @@ pub fn get_model_paths( revision.clone(), )); let model_id = Path::new(&id); - Ok(vec![api_get_file!(qapi, name, model_id)]) + files.push(api_get_file!(qapi, name, model_id)); } - }, + Ok(files) + } None => { let mut filenames = vec![]; for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) { diff --git a/mistralrs-core/src/toml_selector.rs b/mistralrs-core/src/toml_selector.rs index 478d940eb..ae7a2d250 100644 --- a/mistralrs-core/src/toml_selector.rs +++ b/mistralrs-core/src/toml_selector.rs @@ -5,7 +5,7 @@ use serde::Deserialize; use crate::{ GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig, - SpeculativeLoader, + SpeculativeLoader, GGUF_MULTI_FILE_DELIMITER, }; fn default_repeat_last_n() -> usize { @@ -307,7 +307,10 @@ fn loader_from_selected( args.chat_template, Some(tok_model_id), quantized_model_id, - quantized_filename, + quantized_filename + .split(GGUF_MULTI_FILE_DELIMITER) + .map(|s| s.to_string()) + .collect::>(), ) .build(), TomlModelSelected::XLoraGGUF { @@ -324,7 +327,10 @@ fn loader_from_selected( args.chat_template, tok_model_id, quantized_model_id, - quantized_filename, + quantized_filename + .split(GGUF_MULTI_FILE_DELIMITER) + .map(|s| s.to_string()) + .collect::>(), ) .with_xlora( xlora_model_id, @@ -349,7 +355,10 @@ fn loader_from_selected( args.chat_template, tok_model_id, quantized_model_id, - quantized_filename, + quantized_filename + .split(GGUF_MULTI_FILE_DELIMITER) + .map(|s| s.to_string()) + .collect::>(), ) .with_lora( adapters_model_id, diff --git a/mistralrs-core/src/utils/max_seq_len.rs b/mistralrs-core/src/utils/max_seq_len.rs deleted file mode 100644 index 3ac96a29c..000000000 --- a/mistralrs-core/src/utils/max_seq_len.rs +++ /dev/null @@ -1,20 +0,0 @@ -use candle_core::{ - quantized::gguf_file::{Value, ValueType}, - Result, -}; -use tracing::warn; - -/// Extract a u32 or u8 max seq len. 
Warns if error and then uses a default -pub(crate) fn get_gguf_max_seq_len(max_seq_len: Result<&Value>, default: u64) -> u64 { - match max_seq_len { - Ok(m) => match m.value_type() { - ValueType::U32 => m.to_u32().unwrap() as u64, - ValueType::U64 => m.to_u64().unwrap(), - _ => default, - }, - Err(_) => { - warn!("GGUF file does not specify a context window, using {default}."); - default - } - } -} diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 4775afb3d..e5911b836 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod max_seq_len; pub(crate) mod model_config; pub(crate) mod progress; pub(crate) mod tokenizer; diff --git a/mistralrs-core/src/utils/model_config.rs b/mistralrs-core/src/utils/model_config.rs index 7315f27b5..35f3aa11f 100644 --- a/mistralrs-core/src/utils/model_config.rs +++ b/mistralrs-core/src/utils/model_config.rs @@ -1,10 +1,11 @@ use super::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use anyhow::Result; -use candle_core::quantized::{ggml_file, gguf_file}; +use candle_core::quantized::ggml_file; use candle_nn::VarBuilder; use std::{collections::HashMap, path::PathBuf}; use crate::{ + gguf::Content, lora::{LoraConfig, Ordering}, pipeline::ModelPaths, xlora_models::XLoraConfig, @@ -17,12 +18,6 @@ pub struct FileGGML { pub gqa: usize, } -#[derive(derive_more::From)] -pub struct FileGGUF<'a> { - pub ct: gguf_file::Content, - pub reader: &'a mut std::fs::File, -} - #[derive(derive_more::From)] pub struct Device<'a> { pub device: &'a candle_core::Device, @@ -94,7 +89,7 @@ impl<'a> Adapter<'a> { // New type wrappers that segment the distinct parameter sets used by `from_ggml()` + `from_gguf()` methods: pub struct ParamsGGML(pub FileGGML); -pub struct ParamsGGUF<'a>(pub FileGGUF<'a>, pub Device<'a>); +pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(pub Content<'a, R>, pub Device<'a>); // A `None` type vs the `Some` type (`Adapter<'a>`) pub struct NoAdapter {} @@ -103,7 +98,7 @@ pub struct NoAdapter {} // (required workaround to support impl on subtypes, otherwise would use an enum) pub trait QuantParams {} impl QuantParams for ParamsGGML {} -impl QuantParams for ParamsGGUF<'_> {} +impl QuantParams for ParamsGGUF<'_, R> {} // Emulates `Option` but is compatible as a type bound in `impl` for Some vs None pub trait MaybeAdapter {} @@ -155,8 +150,7 @@ pub trait FromGGML { pub trait FromGGUF { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + ct: Content<'_, R>, device: &candle_core::Device, mapper: DeviceMapMetadata, ) -> Result @@ -181,8 +175,7 @@ pub trait FromAdapterGGML { pub trait FromAdapterGGUF { #[allow(clippy::too_many_arguments)] fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + ct: Content<'_, R>, device: &candle_core::Device, lora_config: &[((String, String), LoraConfig)], vb: &VarBuilder, @@ -232,20 +225,20 @@ impl Config> { } } -impl Config, NoAdapter> { +impl Config, NoAdapter> { pub fn try_into_model(self) -> Result { // Destructure props: - let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant; + let ParamsGGUF(content, Device { device, mapper }) = self.quant; // Forwards all structured fields above into the required flattened param sequence: - T::from_gguf(ct, reader, device, mapper) + T::from_gguf(content, device, mapper) } } -impl Config, Adapter<'_>> { +impl Config, Adapter<'_>> { pub fn try_into_model(self) -> Result { // Destructure props: - let ParamsGGUF(FileGGUF { ct, reader }, 
Device { device, mapper }) = self.quant; + let ParamsGGUF(content, Device { device, mapper }) = self.quant; let Adapter { xlora_config, @@ -257,8 +250,7 @@ impl Config, Adapter<'_>> { // Forwards all structured fields above into the required flattened param sequence: T::from_gguf( - ct, - reader, + content, device, lora_config, &vb, @@ -299,10 +291,10 @@ impl TryFrom> for XLoraQLlama { akin! { let &models_gguf = [QLlama, QPhi, QPhi3]; - impl TryFrom>> for *models_gguf { + impl TryFrom>> for *models_gguf { type Error = candle_core::Error; - fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result { + fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result { let config = params.expect_quantized("`Config` should be GGUF Quantized"); config.try_into_model() } @@ -312,10 +304,10 @@ akin! { akin! { let &models_gguf_a = [XLoraQLlama, XLoraQPhi3]; - impl TryFrom>> for *models_gguf_a { + impl TryFrom>> for *models_gguf_a { type Error = candle_core::Error; - fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result { + fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result { let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter"); config.try_into_model() } diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs index 8eef06eeb..dc98c6419 100644 --- a/mistralrs-core/src/xlora_models/quantized_llama.rs +++ b/mistralrs-core/src/xlora_models/quantized_llama.rs @@ -2,12 +2,12 @@ use std::collections::HashMap; +use crate::gguf::Content; use crate::lora::{ get_lora_cfg, AdapterSwapper, LinearLayerLike, LoraConfig, Merge, Ordering, QLoraLinear, }; -use crate::utils::max_seq_len::get_gguf_max_seq_len; +use candle_core::quantized::ggml_file; use candle_core::quantized::QMatMul; -use candle_core::quantized::{ggml_file, gguf_file}; use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, RotaryEmbedding, VarBuilder}; use tqdm::Iter; @@ -447,8 +447,7 @@ impl ModelConfig::FromAdapterGGML for ModelWeights { impl ModelConfig::FromAdapterGGUF for ModelWeights { #[allow(clippy::too_many_arguments)] fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, lora_config: &[((String, String), LoraConfig)], vb: &VarBuilder, @@ -457,46 +456,48 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { mapper: DeviceMapMetadata, preload_adapters: &Option>, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; verify_sanity_gguf( - md_get("general.architecture")?.to_string().unwrap(), + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), "llama", )?; verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?; // Parameter extraction from metadata. - let n_expert = md_get("llama.expert_count") + let n_expert = ct + .get_metadata("llama.expert_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let n_expert_used = md_get("llama.expert_used_count") + let n_expert_used = ct + .get_metadata("llama.expert_used_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? 
as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + let head_count = ct.get_metadata("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("llama.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("llama.rope.dimension_count")?.to_u32()? as usize; // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. - let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + let rms_norm_eps = ct + .get_metadata("llama.attention.layer_norm_rms_epsilon")? + .to_f32()?; - let rope_freq_base = md_get("llama.rope.freq_base") + let rope_freq_base = ct + .get_metadata("llama.rope.freq_base") .and_then(|m| m.to_f32()) .unwrap_or(10000f32); let head_dim = embedding_length / head_count; - let max_seq_len = - get_gguf_max_seq_len(md_get("llama.context_length"), MAX_SEQ_LEN as u64) as usize; + let max_seq_len = ct + .get_metadata("llama.context_length")? + .to_u64() + .unwrap_or(MAX_SEQ_LEN as u64) as usize; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let norm = QRmsNorm::new( - ct.tensor(reader, "output_norm.weight", device)?, - rms_norm_eps, - )?; - let output = ct.tensor(reader, "output.weight", device)?; + let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?; + let output = ct.tensor("output.weight", device)?; let mut layers = Vec::with_capacity(block_count); let mut count = 0; let mapper = mapper.into_mapper(block_count, device)?; @@ -513,18 +514,14 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { DType::F32, )?; - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; - let attention_wo = - ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?; let mlp_or_moe = if n_expert <= 1 { - let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; - let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; - let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?; let cfg_w1 = get_lora_cfg(&feed_forward_w1); let cfg_w2 = get_lora_cfg(&feed_forward_w2); let cfg_w3 = get_lora_cfg(&feed_forward_w3); @@ -562,15 +559,15 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { }) } else { let feed_forward_gate_inp = - ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?; let mut experts = 
Vec::with_capacity(n_expert); for i in 0..n_expert { let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?; let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?; let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?; let cfg_w1 = get_lora_cfg(&feed_forward_w1); let cfg_w2 = get_lora_cfg(&feed_forward_w2); let cfg_w3 = get_lora_cfg(&feed_forward_w3); @@ -613,9 +610,8 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { experts, } }; - let attention_norm = - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; + let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?; let cfgq = get_lora_cfg(&attention_wq); let cfgk = get_lora_cfg(&attention_wk); let cfgv = get_lora_cfg(&attention_wv); diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs index 248bc4175..44feddc34 100644 --- a/mistralrs-core/src/xlora_models/quantized_phi3.rs +++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::repeat_kv; use crate::layers::verify_sanity_gguf; use crate::layers::CausalMasker; @@ -18,7 +19,6 @@ use crate::lora::Ordering; use crate::lora::QLoraLinear; use crate::pipeline::extract_logits; use crate::DeviceMapMetadata; -use candle_core::quantized::gguf_file; use candle_core::quantized::QMatMul; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; @@ -216,8 +216,7 @@ fn precomput_freqs_cis( impl ModelConfig::FromAdapterGGUF for ModelWeights { #[allow(clippy::too_many_arguments)] fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, lora_config: &[((String, String), LoraConfig)], vb: &VarBuilder, @@ -226,36 +225,39 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { mapper: DeviceMapMetadata, preload_adapters: &Option>, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?; + verify_sanity_gguf( + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), + "phi3", + )?; verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?; // Parameter extraction from metadata. - let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("phi3.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize; - let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize; - let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize; - let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; - let context_window = md_get("phi3.context_length")?.to_u32()? as usize; + let head_count = ct.get_metadata("phi3.attention.head_count")?.to_u32()? 
as usize; + let head_count_kv = ct.get_metadata("phi3.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("phi3.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("phi3.embedding_length")?.to_u32()? as usize; + let i_size = ct.get_metadata("phi3.feed_forward_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("phi3.rope.dimension_count")?.to_u32()? as usize; + let rms_eps = ct + .get_metadata("phi3.attention.layer_norm_rms_epsilon")? + .to_f32()? as f64; + let context_window = ct.get_metadata("phi3.context_length")?.to_u32()? as usize; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?; - let output = QMatMul::from_qtensor(ct.tensor(reader, "output.weight", device)?)?; + let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?; + let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; let mut count = 0; for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let device = mapper.device_for(layer_idx, false).unwrap_or(device); - let ffn_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; - let ffn_down = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?; + let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?; let cfg_up = get_lora_cfg(&ffn_up); let cfg_down = get_lora_cfg(&ffn_down); let mlp = Mlp { @@ -282,15 +284,15 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { i_size, }; let attn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?, rms_eps, )?; let ffn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?, rms_eps, )?; - let qkv = ct.tensor(reader, &format!("{prefix}.attn_qkv.weight"), device)?; - let output = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?; + let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?; let cfg_qkv = get_lora_cfg(&qkv); let cfg_out = get_lora_cfg(&output); layers.push(LayerWeights { diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md index 359ac00e8..ebf49b824 100644 --- a/mistralrs-pyo3/API.md +++ b/mistralrs-pyo3/API.md @@ -9,7 +9,7 @@ These are API docs for the `mistralrs` package. ## `Which` -Each `*_model_id` may be a HF hub repo or a local path. +Each `*_model_id` may be a HF hub repo or a local path. For quantized GGUF models, a list is accepted if multiples files must be specified. Additionally, for models without quantization, the model architecture should be provided as the `arch` parameter in contrast to GGUF models which encode the architecture in the file. 
It should be one of the following: - `mistral` @@ -49,13 +49,13 @@ class Which(Enum): class GGUF: tok_model_id: str quantized_model_id: str - quantized_filename: str + quantized_filename: str | list[str] repeat_last_n: int = 64 @dataclass class XLoraGGUF: tok_model_id: str quantized_model_id: str - quantized_filename: str + quantized_filename: str | list[str] xlora_model_id: str order: str tgt_non_granular_index: int | None = None @@ -64,7 +64,7 @@ class Which(Enum): class LoraGGUF: tok_model_id: str quantized_model_id: str - quantized_filename: str + quantized_filename: str | list[str] adapters_model_id: str order: str repeat_last_n: int = 64 diff --git a/mistralrs-pyo3/README.md b/mistralrs-pyo3/README.md index 1d57af08b..5e058da37 100644 --- a/mistralrs-pyo3/README.md +++ b/mistralrs-pyo3/README.md @@ -115,18 +115,18 @@ We also provide [a cookbook here](../examples/python/cookbook.ipynb)! ## Example ```python -from mistralrs import ModelKind, MistralLoader, ChatCompletionRequest - -kind = ModelKind.QuantizedGGUF -loader = MistralLoader( - model_id="mistralai/Mistral-7B-Instruct-v0.1", - kind=kind, - no_kv_cache=False, - repeat_last_n=64, - quantized_model_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF", - quantized_filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf", +from mistralrs import Runner, Which, ChatCompletionRequest + +runner = Runner( + which=Which.GGUF( + tok_model_id="mistralai/Mistral-7B-Instruct-v0.1", + quantized_model_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF", + quantized_filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf", + tokenizer_json=None, + repeat_last_n=64, + ) ) -runner = loader.load() + res = runner.send_chat_completion_request( ChatCompletionRequest( model="mistral", @@ -134,10 +134,12 @@ res = runner.send_chat_completion_request( {"role": "user", "content": "Tell me a story about the Rust type system."} ], max_tokens=256, - frequency_penalty=1.0, + presence_penalty=1.0, top_p=0.1, temperature=0.1, ) ) -print(res) +print(res.choices[0].message.content) +print(res.usage) + ``` \ No newline at end of file diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index f1d7c46c7..a307bfd6f 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -95,13 +95,13 @@ class Which(Enum): class GGUF: tok_model_id: str quantized_model_id: str - quantized_filename: str + quantized_filename: str | list[str] repeat_last_n: int = 64 @dataclass class XLoraGGUF: tok_model_id: str quantized_model_id: str - quantized_filename: str + quantized_filename: str | list[str] xlora_model_id: str order: str tgt_non_granular_index: int | None = None @@ -110,7 +110,7 @@ class Which(Enum): class LoraGGUF: tok_model_id: str quantized_model_id: str - quantized_filename: str + quantized_filename: str | list[str] adapters_model_id: str order: str repeat_last_n: int = 64 diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 2f15d8ac5..dd6918dd2 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -177,7 +177,9 @@ fn parse_which( chat_template, tok_model_id, quantized_model_id, - quantized_filename, + quantized_filename + .map_left(|file| vec![file]) + .unwrap_right(), ) .build(), Which::XLoraGGUF { @@ -195,7 +197,9 @@ fn parse_which( chat_template, tok_model_id, quantized_model_id, - quantized_filename, + quantized_filename + .map_left(|file| vec![file]) + .unwrap_right(), ) .with_xlora( xlora_model_id, @@ -222,7 +226,9 @@ fn parse_which( chat_template, tok_model_id, quantized_model_id, - quantized_filename, + 
quantized_filename + .map_left(|file| vec![file]) + .unwrap_right(), ) .with_lora( adapters_model_id, diff --git a/mistralrs-pyo3/src/which.rs b/mistralrs-pyo3/src/which.rs index a5a33a612..0b121d608 100644 --- a/mistralrs-pyo3/src/which.rs +++ b/mistralrs-pyo3/src/which.rs @@ -1,3 +1,4 @@ +use either::Either; use mistralrs_core::NormalLoaderType; use pyo3::pyclass; @@ -58,14 +59,14 @@ pub enum Which { GGUF { tok_model_id: Option, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Either>, repeat_last_n: Option, }, XLoraGGUF { tok_model_id: Option, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Either>, repeat_last_n: Option, xlora_model_id: String, order: String, @@ -75,7 +76,7 @@ pub enum Which { LoraGGUF { tok_model_id: Option, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Either>, repeat_last_n: Option, adapters_model_id: String, order: String, diff --git a/mistralrs/examples/gguf_locally/main.rs b/mistralrs/examples/gguf_locally/main.rs index b04fc9fa5..e38e7e8c4 100644 --- a/mistralrs/examples/gguf_locally/main.rs +++ b/mistralrs/examples/gguf_locally/main.rs @@ -17,7 +17,7 @@ fn setup() -> anyhow::Result> { Some("chat_templates/mistral.json".to_string()), None, ".".to_string(), - "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(), + vec!["mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string()], ) .build(); // Load, into a Pipeline diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs index b6539edaf..076ad9978 100644 --- a/mistralrs/examples/quantized/main.rs +++ b/mistralrs/examples/quantized/main.rs @@ -15,7 +15,7 @@ fn setup() -> anyhow::Result> { None, Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()), "TheBloke/Mistral-7B-Instruct-v0.1-GGUF".to_string(), - "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(), + vec!["mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string()], ) .build(); // Load, into a Pipeline From 19ca7acf39ada6a1640d1ab0217bdcf65e1106bb Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 4 Jun 2024 21:19:13 -0400 Subject: [PATCH 2/9] Organize normal loading metadata (#381) * Organize normal loading metadata * Fix --- mistralrs-core/src/models/gemma.rs | 19 +-- mistralrs-core/src/models/llama.rs | 21 +-- mistralrs-core/src/models/mistral.rs | 35 ++--- mistralrs-core/src/models/mixtral.rs | 21 +-- mistralrs-core/src/models/phi2.rs | 21 +-- mistralrs-core/src/models/phi3.rs | 21 +-- mistralrs-core/src/models/qwen2.rs | 21 +-- mistralrs-core/src/pipeline/macros.rs | 24 ++-- mistralrs-core/src/pipeline/mod.rs | 4 +- mistralrs-core/src/pipeline/normal_loaders.rs | 126 ++++++------------ mistralrs-core/src/xlora_models/gemma.rs | 19 +-- mistralrs-core/src/xlora_models/llama.rs | 21 +-- mistralrs-core/src/xlora_models/mistral.rs | 21 +-- mistralrs-core/src/xlora_models/mixtral.rs | 21 +-- mistralrs-core/src/xlora_models/phi2.rs | 21 +-- mistralrs-core/src/xlora_models/phi3.rs | 21 +-- 16 files changed, 199 insertions(+), 238 deletions(-) diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index a74b1558c..89871580f 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -8,8 +8,7 @@ use candle_nn::{linear_b as linear, Activation, RotaryEmbedding, VarBuilder}; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, QLinear, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, 
NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; fn default_max_position_embeddings() -> usize { @@ -319,11 +318,11 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let vb_m = vb.pp("model"); let embed_tokens = candle_nn::embedding( cfg.vocab_size, @@ -337,7 +336,9 @@ impl Model { cfg.rope_theta as f32, cfg.head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -347,7 +348,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -362,7 +363,7 @@ impl Model { layers, norm, lm_head, - device: real_device, + device: normal_loading_metadata.real_device, hidden_size: cfg.hidden_size, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: default_max_position_embeddings(), diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 55554a467..ab2b1df5a 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -10,8 +10,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, IsqModel, NormalLoadingMetadata, NormalModel}, }; #[derive(Debug, Clone, Deserialize)] @@ -294,11 +293,11 @@ impl Llama { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let wte = embedding( cfg.vocab_size, cfg.hidden_size, @@ -307,7 +306,7 @@ impl Llama { let lm_head = linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; let ln_f = RmsNorm::new( cfg.hidden_size, @@ -322,7 +321,9 @@ impl Llama { cfg.rope_theta, head_dim, cfg.max_position_embeddings, - mapper.device_for(i, false).unwrap_or(&real_device), + mapper + .device_for(i, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), ) @@ -333,7 +334,7 @@ impl Llama { cfg, &*mapper, i, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, ) .expect("Failed to load block.") @@ -346,7 +347,7 @@ impl Llama { ln_f, lm_head: QMatMul::Tensor(lm_head.weight().clone()), kv_cache: crate::pipeline::Cache::new(cfg.num_hidden_layers, false), - device: real_device, + device: normal_loading_metadata.real_device, mapper, }) } diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index d7f2c8b09..f41dd11c1 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -8,8 +8,7 @@ use std::sync::Arc; use crate::{ 
device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; #[derive(Debug, Clone, PartialEq)] @@ -280,21 +279,11 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); let vb_lm_head = vb.pp("lm_head"); - Self::new_inner( - cfg, - vb_m, - vb_lm_head, - is_gptx, - mapper, - loading_isq, - real_device, - ) + Self::new_inner(cfg, vb_m, vb_lm_head, is_gptx, normal_loading_metadata) } pub fn new_inner( @@ -302,11 +291,11 @@ impl Model { vb_m: VarBuilder, vb_lm_head: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -320,7 +309,9 @@ impl Model { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb_m.dtype(), )?); @@ -330,7 +321,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -342,7 +333,7 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb_lm_head, loading_isq), + mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -350,7 +341,7 @@ impl Model { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index 11cb33294..0299037d0 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -11,8 +11,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; /// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 @@ -383,12 +382,12 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -402,7 +401,9 @@ impl Model { cfg.rope_theta as f32, head_dim, 
cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -412,7 +413,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -424,7 +425,7 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -432,7 +433,7 @@ impl Model { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index d1a8a89cd..187e1fcf8 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -14,8 +14,7 @@ use serde::Deserialize; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, QLinear, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; // https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py @@ -287,12 +286,12 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = embedding( cfg.vocab_size, cfg.hidden_size, @@ -312,7 +311,9 @@ impl Model { cfg.head_dim(), (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?; @@ -321,7 +322,7 @@ impl Model { vb_m.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, )?; layers.push(layer) @@ -329,7 +330,7 @@ impl Model { let lm_head = linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -337,7 +338,7 @@ impl Model { final_layernorm, lm_head: QLinear::from_linear(lm_head), cache: Cache::new(cfg.num_hidden_layers, false), - device: real_device, + device: normal_loading_metadata.real_device, max_seq_len: cfg.max_position_embeddings, mapper, }) diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index c00e60ea7..1fbc1218b 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -13,8 +13,7 @@ use crate::{ repeat_kv, CausalMasker, MatMul, PhiRopeConfig, PhiRotaryEmbedding, RmsNorm, ScaledDotProductAttention, }, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; // 
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json @@ -291,12 +290,12 @@ impl Model { cfg: &Config, vb: VarBuilder, _is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -308,7 +307,9 @@ impl Model { let rotary_emb = Arc::new(PhiRotaryEmbedding::new( vb.dtype(), cfg.clone(), - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), )?); let layer = DecoderLayer::new( rotary_emb.clone(), @@ -316,7 +317,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -328,14 +329,14 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, layers, norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), - device: real_device, + device: normal_loading_metadata.real_device, cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index eae06e4ab..a6103fb50 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -7,8 +7,7 @@ use std::sync::Arc; use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, MatMul, QLinear, RmsNorm, ScaledDotProductAttention}, - pipeline::{extract_logits, Cache, IsqModel, NormalModel}, - DeviceMapMetadata, + pipeline::{extract_logits, Cache, IsqModel, NormalLoadingMetadata, NormalModel}, }; #[derive(Debug, Clone, PartialEq, serde::Deserialize)] @@ -269,12 +268,12 @@ impl Model { cfg: &Config, vb: VarBuilder, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -288,7 +287,9 @@ impl Model { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -298,7 +299,7 @@ impl Model { vb_l.pp(layer_idx), &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, )?; layers.push(layer) } @@ -310,7 +311,7 @@ impl Model { let lm_head = linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -318,7 +319,7 @@ impl Model { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, 
cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, mapper, diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 46de13167..d3795c413 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -314,9 +314,11 @@ macro_rules! normal_model_loader { &$config, $use_flash_attn, vb, - $mapper, - $loading_isq, - $real_device, + $crate::pipeline::NormalLoadingMetadata { + mapper: $mapper, + loading_isq: $loading_isq, + real_device: $real_device, + }, )? }}; } @@ -372,9 +374,11 @@ macro_rules! xlora_model_loader { $paths.get_adapter_configs().as_ref().unwrap(), Some($paths.get_classifier_config().as_ref().unwrap().clone()), $paths.get_ordering().as_ref().unwrap().clone(), - $mapper, - $loading_isq, - $real_device, + $crate::pipeline::NormalLoadingMetadata { + mapper: $mapper, + loading_isq: $loading_isq, + real_device: $real_device, + }, &$crate::utils::varbuilder_utils::load_preload_adapters( $paths.get_lora_preload_adapter_info(), $dtype.unwrap_or($default_dtype), @@ -413,9 +417,11 @@ macro_rules! lora_model_loader { $paths.get_adapter_configs().as_ref().unwrap(), None, $paths.get_ordering().as_ref().unwrap().clone(), - $mapper, - $loading_isq, - $real_device, + $crate::pipeline::NormalLoadingMetadata { + mapper: $mapper, + loading_isq: $loading_isq, + real_device: $real_device, + }, &$crate::utils::varbuilder_utils::load_preload_adapters( $paths.get_lora_preload_adapter_info(), $dtype.unwrap_or($default_dtype), diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 250987be1..262ec0663 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -24,8 +24,8 @@ pub use gguf::{GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConf pub use isq::IsqModel; pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig}; pub use normal_loaders::{ - GemmaLoader, LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, NormalModelLoader, - Phi2Loader, Phi3Loader, Qwen2Loader, + GemmaLoader, LlamaLoader, MistralLoader, MixtralLoader, NormalLoaderType, + NormalLoadingMetadata, NormalModelLoader, Phi2Loader, Phi3Loader, Qwen2Loader, }; pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths, XLoraPaths}; pub(crate) use processing::{BasicProcessor, Processor, ProcessorCreator}; diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs index 09e0140a9..5237834f2 100644 --- a/mistralrs-core/src/pipeline/normal_loaders.rs +++ b/mistralrs-core/src/pipeline/normal_loaders.rs @@ -11,15 +11,23 @@ use pyo3::pyclass; use serde::Deserialize; +/// Metadata for loading a model with ISQ or device mapping. 
+pub struct NormalLoadingMetadata { + // Device mapping metadata which can be used to construct a concrete device mapper + pub mapper: DeviceMapMetadata, + // Flag to check if loading in ISQ + pub loading_isq: bool, + // Device mapping target device (the one that is not the cpu) + pub real_device: Device, +} + pub trait NormalModelLoader { fn load( &self, config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result>; #[allow(clippy::too_many_arguments)] fn load_xlora( @@ -30,9 +38,7 @@ pub trait NormalModelLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result>; fn is_gptx(&self) -> bool; @@ -127,17 +133,13 @@ impl NormalModelLoader for MistralLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::mistral::Model::new( &MistralBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -148,9 +150,7 @@ impl NormalModelLoader for MistralLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraMistral::new( @@ -160,9 +160,7 @@ impl NormalModelLoader for MistralLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -233,17 +231,13 @@ impl NormalModelLoader for GemmaLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::gemma::Model::new( &GemmaBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -254,9 +248,7 @@ impl NormalModelLoader for GemmaLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraGemma::new( @@ -266,9 +258,7 @@ impl NormalModelLoader for GemmaLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -331,17 +321,13 @@ impl NormalModelLoader for LlamaLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::llama::Llama::new( &LlamaBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -352,9 +338,7 @@ impl NormalModelLoader for LlamaLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - 
mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraLlama::new( @@ -364,9 +348,7 @@ impl NormalModelLoader for LlamaLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -430,17 +412,13 @@ impl NormalModelLoader for MixtralLoader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::mixtral::Model::new( &MixtralBasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -451,9 +429,7 @@ impl NormalModelLoader for MixtralLoader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraMixtral::new( @@ -463,9 +439,7 @@ impl NormalModelLoader for MixtralLoader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -529,17 +503,13 @@ impl NormalModelLoader for Phi2Loader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::phi2::Model::new( &Phi2BasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -550,9 +520,7 @@ impl NormalModelLoader for Phi2Loader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraPhi2::new( @@ -562,9 +530,7 @@ impl NormalModelLoader for Phi2Loader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -639,17 +605,13 @@ impl NormalModelLoader for Phi3Loader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::phi3::Model::new( &Phi3BasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -660,9 +622,7 @@ impl NormalModelLoader for Phi3Loader { lora_config: &[((String, String), LoraConfig)], xlora_config: Option, xlora_ordering: Ordering, - mapper: DeviceMapMetadata, - loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result> { Ok(Box::new(xlora_models::XLoraPhi3::new( @@ -672,9 +632,7 @@ impl NormalModelLoader for Phi3Loader { xlora_config, xlora_ordering, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, preload_adapters, )?)) } @@ -740,17 +698,13 @@ impl NormalModelLoader for Qwen2Loader { config: &str, use_flash_attn: bool, vb: VarBuilder, - mapper: DeviceMapMetadata, - 
loading_isq: bool, - device: Device, + normal_loading_metadata: NormalLoadingMetadata, ) -> Result> { Ok(Box::new(models::qwen2::Model::new( &Qwen2BasicConfig::deserialize(config, use_flash_attn)?, vb, self.is_gptx(), - mapper, - loading_isq, - device, + normal_loading_metadata, )?)) } fn load_xlora( @@ -761,9 +715,7 @@ impl NormalModelLoader for Qwen2Loader { _lora_config: &[((String, String), LoraConfig)], _xlora_config: Option, _xlora_ordering: Ordering, - _mapper: DeviceMapMetadata, - _loading_isq: bool, - _device: Device, + _normal_loading_metadata: NormalLoadingMetadata, _preload_adapters: &Option>, ) -> Result> { todo!() diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index bde161cee..57bce3349 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ layers::ScaledDotProductAttention, lora::{linear_b as linear, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor, D}; use candle_nn::{RotaryEmbedding, VarBuilder}; @@ -17,7 +17,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, QLinear}, models::gemma::Config, pipeline::{extract_logits, Cache, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -471,13 +470,13 @@ impl XLoraModel { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -491,7 +490,9 @@ impl XLoraModel { cfg.rope_theta as f32, cfg.head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -504,7 +505,7 @@ impl XLoraModel { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -548,7 +549,7 @@ impl XLoraModel { layers, norm, lm_head: QLinear::from_linear(lm_head), - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), hidden_size: cfg.hidden_size, cache: Cache::new(cfg.num_hidden_layers, true), diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs index db1e6f92d..55046a40b 100644 --- a/mistralrs-core/src/xlora_models/llama.rs +++ b/mistralrs-core/src/xlora_models/llama.rs @@ -15,8 +15,7 @@ use crate::{ device_map::DeviceMapper, layers::{repeat_kv, CausalMasker, QLinear, RmsNorm}, models::llama::Config, - pipeline::{self, extract_logits, LayerCaches, NormalModel}, - DeviceMapMetadata, + pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel}, }; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -552,13 +551,13 @@ impl XLoraLlama { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - 
real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let dtype = vb.dtype(); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let wte = embedding( cfg.vocab_size, cfg.hidden_size, @@ -567,7 +566,7 @@ impl XLoraLlama { let lm_head = candle_nn::linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; let ln_f = RmsNorm::new( cfg.hidden_size, @@ -583,7 +582,9 @@ impl XLoraLlama { cfg.rope_theta, head_dim, cfg.max_position_embeddings, - mapper.device_for(i, false).unwrap_or(&real_device), + mapper + .device_for(i, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), ) @@ -597,7 +598,7 @@ impl XLoraLlama { &xlora_ordering, &*mapper, i, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, preload_adapters, ) @@ -639,7 +640,7 @@ impl XLoraLlama { ln_f, lm_head: QLinear::from_linear(lm_head), kv_cache: pipeline::Cache::new(cfg.num_hidden_layers, true), - device: real_device, + device: normal_loading_metadata.real_device, xlora_classifier: xlora_config.map(|xlora_config| { XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap() }), diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index 4d240a7e5..620cd6fc6 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -3,7 +3,7 @@ use crate::{ layers::ScaledDotProductAttention, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor}; @@ -17,7 +17,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, QLinear, RmsNorm}, models::mistral::Config, pipeline::{extract_logits, Cache, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker}; @@ -438,12 +437,12 @@ impl XLoraModel { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let vb_m = vb.pp("model"); let embed_tokens = candle_nn::embedding( cfg.vocab_size, @@ -459,7 +458,9 @@ impl XLoraModel { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -472,7 +473,7 @@ impl XLoraModel { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -513,7 +514,7 @@ impl XLoraModel { let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { 
embed_tokens, @@ -521,7 +522,7 @@ impl XLoraModel { norm, lm_head: QLinear::from_linear(lm_head), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), cache: Cache::new(cfg.num_hidden_layers, true), max_seq_len: cfg.max_position_embeddings, diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index 65e715dfe..223bedd72 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -3,7 +3,7 @@ use crate::{ layers::{MatMul, ScaledDotProductAttention}, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; /// Mixtral Model /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py @@ -19,7 +19,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, RmsNorm}, models::mixtral::Config, pipeline::{extract_logits, Cache, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -576,13 +575,13 @@ impl XLoraModel { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -597,7 +596,9 @@ impl XLoraModel { cfg.rope_theta as f32, head_dim, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?); @@ -610,7 +611,7 @@ impl XLoraModel { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -650,7 +651,7 @@ impl XLoraModel { let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -658,7 +659,7 @@ impl XLoraModel { norm, lm_head: QMatMul::Tensor(lm_head.weight().clone()), sliding_window: cfg.sliding_window, - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), cache: Cache::new(cfg.num_hidden_layers, false), max_seq_len: cfg.max_position_embeddings, diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs index 327ca8fde..a64852274 100644 --- a/mistralrs-core/src/xlora_models/phi2.rs +++ b/mistralrs-core/src/xlora_models/phi2.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ layers::ScaledDotProductAttention, lora::{linear, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; /// Phi model. 
/// https://huggingface.co/microsoft/phi-2 @@ -24,7 +24,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, QLinear}, models::phi2::Config, pipeline::{extract_logits, NormalModel}, - DeviceMapMetadata, }; use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig}; @@ -427,13 +426,13 @@ impl Model { xlora_config: Option, xlora_ordering: Ordering, is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = embedding( cfg.vocab_size, cfg.hidden_size, @@ -454,7 +453,9 @@ impl Model { cfg.head_dim(), (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize, cfg.max_position_embeddings, - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), is_gptx, vb.dtype(), )?; @@ -466,7 +467,7 @@ impl Model { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, rotary_emb, preload_adapters, )?; @@ -496,7 +497,7 @@ impl Model { let lm_head = candle_nn::linear( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, @@ -504,7 +505,7 @@ impl Model { final_layernorm, lm_head: QLinear::from_linear(lm_head), cache: Cache::new(cfg.num_hidden_layers, true), - device: real_device, + device: normal_loading_metadata.real_device, max_seq_len: cfg.max_position_embeddings, dtype: vb.dtype(), xlora_classifier: xlora_config.map(|xlora_config| { diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 1fabcf286..875453140 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -5,7 +5,7 @@ use crate::{ layers::ScaledDotProductAttention, lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering}, - pipeline::IsqModel, + pipeline::{IsqModel, NormalLoadingMetadata}, }; use candle_core::{quantized::QMatMul, DType, Device, Module, Result, Tensor, D}; use candle_nn::VarBuilder; @@ -18,7 +18,6 @@ use crate::{ layers::{repeat_kv, CausalMasker, PhiRotaryEmbedding, QLinear, RmsNorm}, models::phi3::Config, pipeline::{extract_logits, NormalModel}, - DeviceMapMetadata, }; use crate::pipeline::Cache; @@ -387,13 +386,13 @@ impl Model { xlora_config: Option, xlora_ordering: Ordering, _is_gptx: bool, - mapper: DeviceMapMetadata, - loading_isq: bool, - real_device: Device, + normal_loading_metadata: NormalLoadingMetadata, preload_adapters: &Option>, ) -> Result { let vb_m = vb.pp("model"); - let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?; + let mapper = normal_loading_metadata + .mapper + .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?; let embed_tokens = candle_nn::embedding( cfg.vocab_size, cfg.hidden_size, @@ -406,7 +405,9 @@ impl Model { let rotary_emb = Arc::new(PhiRotaryEmbedding::new( vb.dtype(), cfg.clone(), - mapper.device_for(layer_idx, false).unwrap_or(&real_device), + mapper + .device_for(layer_idx, false) + .unwrap_or(&normal_loading_metadata.real_device), )?); let layer = DecoderLayer::new( rotary_emb.clone(), @@ 
-417,7 +418,7 @@ impl Model { &xlora_ordering, &*mapper, layer_idx, - loading_isq, + normal_loading_metadata.loading_isq, preload_adapters, )?; layers.push(layer) @@ -449,14 +450,14 @@ impl Model { let lm_head = candle_nn::linear_no_bias( cfg.hidden_size, cfg.vocab_size, - mapper.set_nm_device(vb.pp("lm_head"), loading_isq), + mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq), )?; Ok(Self { embed_tokens, layers, norm, lm_head: QLinear::from_linear(lm_head), - device: real_device, + device: normal_loading_metadata.real_device, dtype: vb.dtype(), cache: Cache::new(cfg.num_hidden_layers, true), max_seq_len: cfg.max_position_embeddings, From 818808b0a3d032493a87aa783337dd021e938982 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 4 Jun 2024 21:23:02 -0400 Subject: [PATCH 3/9] Bump version 0.1.13 -> 0.1.14 (#382) --- Cargo.toml | 2 +- mistralrs-bench/Cargo.toml | 2 +- mistralrs-pyo3/Cargo.toml | 2 +- mistralrs-pyo3/Cargo_template.toml | 2 +- mistralrs-pyo3/pyproject.toml | 2 +- mistralrs-pyo3/pyproject_template.toml | 2 +- mistralrs-server/Cargo.toml | 2 +- mistralrs/Cargo.toml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 280bdbf47..fd7e341f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.13" +version = "0.1.14" edition = "2021" description = "Fast and easy LLM serving." homepage = "https://github.com/EricLBuehler/mistral.rs" diff --git a/mistralrs-bench/Cargo.toml b/mistralrs-bench/Cargo.toml index 60b41bddf..66eb1a2b8 100644 --- a/mistralrs-bench/Cargo.toml +++ b/mistralrs-bench/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true serde.workspace = true serde_json.workspace = true clap.workspace = true -mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.14", path = "../mistralrs-core" } tracing.workspace = true either.workspace = true tokio.workspace = true diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml index 574d7b1c7..7b61427c7 100644 --- a/mistralrs-pyo3/Cargo.toml +++ b/mistralrs-pyo3/Cargo.toml @@ -17,7 +17,7 @@ doc = false [dependencies] pyo3.workspace = true -mistralrs-core = { version = "0.1.13", path = "../mistralrs-core", features = ["pyo3_macros"] } +mistralrs-core = { version = "0.1.14", path = "../mistralrs-core", features = ["pyo3_macros"] } serde.workspace = true serde_json.workspace = true candle-core.workspace = true diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index d7be7710c..3b3cb48b7 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -17,7 +17,7 @@ doc = false [dependencies] pyo3.workspace = true -mistralrs-core = { version = "0.1.13", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } +mistralrs-core = { version = "0.1.14", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", features=["$feature_name"] } diff --git a/mistralrs-pyo3/pyproject.toml b/mistralrs-pyo3/pyproject.toml index e684a4091..094b9dcd2 100644 --- a/mistralrs-pyo3/pyproject.toml +++ b/mistralrs-pyo3/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "mistralrs" -version = "0.1.13" +version = "0.1.14" requires-python = ">=3.8" classifiers = [ 
"Programming Language :: Rust", diff --git a/mistralrs-pyo3/pyproject_template.toml b/mistralrs-pyo3/pyproject_template.toml index 3fa3ad241..7ba27d1a1 100644 --- a/mistralrs-pyo3/pyproject_template.toml +++ b/mistralrs-pyo3/pyproject_template.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "$name" -version = "0.1.13" +version = "0.1.14" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/mistralrs-server/Cargo.toml b/mistralrs-server/Cargo.toml index 4355a15a2..6a07e9a14 100644 --- a/mistralrs-server/Cargo.toml +++ b/mistralrs-server/Cargo.toml @@ -22,7 +22,7 @@ axum = { version = "0.7.4", features = ["tokio"] } tower-http = { version = "0.5.1", features = ["cors"]} utoipa = { version = "4.2", features = ["axum_extras"] } utoipa-swagger-ui = { version = "7.1.0", features = ["axum"]} -mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.14", path = "../mistralrs-core" } dyn-fmt = "0.4.0" indexmap.workspace = true accelerate-src = { workspace = true, optional = true } diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml index e0134537f..a282d05f5 100644 --- a/mistralrs/Cargo.toml +++ b/mistralrs/Cargo.toml @@ -12,7 +12,7 @@ license.workspace = true homepage.workspace = true [dependencies] -mistralrs-core = { version = "0.1.13", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.14", path = "../mistralrs-core" } anyhow.workspace = true tokio.workspace = true candle-core.workspace = true From 9712da61e89b828925b6e6993a94be2c13a9c09f Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 4 Jun 2024 22:14:54 -0400 Subject: [PATCH 4/9] Patch incorrect unwrap and bump version (#383) * Patch incorrect unwrap * Bump version to 0.1.15 --- Cargo.toml | 2 +- mistralrs-bench/Cargo.toml | 2 +- mistralrs-pyo3/Cargo.toml | 2 +- mistralrs-pyo3/Cargo_template.toml | 2 +- mistralrs-pyo3/pyproject.toml | 2 +- mistralrs-pyo3/pyproject_template.toml | 2 +- mistralrs-pyo3/src/lib.rs | 21 ++++++++++++--------- mistralrs-server/Cargo.toml | 2 +- mistralrs/Cargo.toml | 2 +- 9 files changed, 20 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fd7e341f3..f5d322a42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.14" +version = "0.1.15" edition = "2021" description = "Fast and easy LLM serving." 
homepage = "https://github.com/EricLBuehler/mistral.rs" diff --git a/mistralrs-bench/Cargo.toml b/mistralrs-bench/Cargo.toml index 66eb1a2b8..0c4ebdb80 100644 --- a/mistralrs-bench/Cargo.toml +++ b/mistralrs-bench/Cargo.toml @@ -17,7 +17,7 @@ candle-core.workspace = true serde.workspace = true serde_json.workspace = true clap.workspace = true -mistralrs-core = { version = "0.1.14", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.15", path = "../mistralrs-core" } tracing.workspace = true either.workspace = true tokio.workspace = true diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml index 7b61427c7..aed0a877a 100644 --- a/mistralrs-pyo3/Cargo.toml +++ b/mistralrs-pyo3/Cargo.toml @@ -17,7 +17,7 @@ doc = false [dependencies] pyo3.workspace = true -mistralrs-core = { version = "0.1.14", path = "../mistralrs-core", features = ["pyo3_macros"] } +mistralrs-core = { version = "0.1.15", path = "../mistralrs-core", features = ["pyo3_macros"] } serde.workspace = true serde_json.workspace = true candle-core.workspace = true diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml index 3b3cb48b7..683239db7 100644 --- a/mistralrs-pyo3/Cargo_template.toml +++ b/mistralrs-pyo3/Cargo_template.toml @@ -17,7 +17,7 @@ doc = false [dependencies] pyo3.workspace = true -mistralrs-core = { version = "0.1.14", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } +mistralrs-core = { version = "0.1.15", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] } serde.workspace = true serde_json.workspace = true candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", features=["$feature_name"] } diff --git a/mistralrs-pyo3/pyproject.toml b/mistralrs-pyo3/pyproject.toml index 094b9dcd2..8b8339cfc 100644 --- a/mistralrs-pyo3/pyproject.toml +++ b/mistralrs-pyo3/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "mistralrs" -version = "0.1.14" +version = "0.1.15" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/mistralrs-pyo3/pyproject_template.toml b/mistralrs-pyo3/pyproject_template.toml index 7ba27d1a1..2ec480ef4 100644 --- a/mistralrs-pyo3/pyproject_template.toml +++ b/mistralrs-pyo3/pyproject_template.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "$name" -version = "0.1.14" +version = "0.1.15" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index dd6918dd2..1c9ff3916 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -177,9 +177,10 @@ fn parse_which( chat_template, tok_model_id, quantized_model_id, - quantized_filename - .map_left(|file| vec![file]) - .unwrap_right(), + match quantized_filename { + Either::Left(l) => vec![l], + Either::Right(r) => r, + }, ) .build(), Which::XLoraGGUF { @@ -197,9 +198,10 @@ fn parse_which( chat_template, tok_model_id, quantized_model_id, - quantized_filename - .map_left(|file| vec![file]) - .unwrap_right(), + match quantized_filename { + Either::Left(l) => vec![l], + Either::Right(r) => r, + }, ) .with_xlora( xlora_model_id, @@ -226,9 +228,10 @@ fn parse_which( chat_template, tok_model_id, quantized_model_id, - quantized_filename - .map_left(|file| vec![file]) - .unwrap_right(), + match quantized_filename { + Either::Left(l) => vec![l], + Either::Right(r) => r, + }, ) .with_lora( adapters_model_id, diff --git a/mistralrs-server/Cargo.toml b/mistralrs-server/Cargo.toml 
index 6a07e9a14..7616238f3 100644 --- a/mistralrs-server/Cargo.toml +++ b/mistralrs-server/Cargo.toml @@ -22,7 +22,7 @@ axum = { version = "0.7.4", features = ["tokio"] } tower-http = { version = "0.5.1", features = ["cors"]} utoipa = { version = "4.2", features = ["axum_extras"] } utoipa-swagger-ui = { version = "7.1.0", features = ["axum"]} -mistralrs-core = { version = "0.1.14", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.15", path = "../mistralrs-core" } dyn-fmt = "0.4.0" indexmap.workspace = true accelerate-src = { workspace = true, optional = true } diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml index a282d05f5..fa1246106 100644 --- a/mistralrs/Cargo.toml +++ b/mistralrs/Cargo.toml @@ -12,7 +12,7 @@ license.workspace = true homepage.workspace = true [dependencies] -mistralrs-core = { version = "0.1.14", path = "../mistralrs-core" } +mistralrs-core = { version = "0.1.15", path = "../mistralrs-core" } anyhow.workspace = true tokio.workspace = true candle-core.workspace = true From 798adb4280e59d23d66c35fe2f4958886d300327 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 5 Jun 2024 05:33:20 -0400 Subject: [PATCH 5/9] More verbose logging during loading (#385) * More verbose logging when loading * More logging --- mistralrs-core/src/lib.rs | 4 ++-- mistralrs-core/src/pipeline/macros.rs | 20 +++++++++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index da1757ab1..0be8760b5 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -154,7 +154,7 @@ fn set_gemm_reduced_precision_f16() { let a = Tensor::zeros((2, 2), DType::BF16, &Device::new_cuda(0).unwrap()).unwrap(); candle_core::cuda::set_gemm_reduced_precision_bf16(true); match a.matmul(&a) { - Ok(_) => (), + Ok(_) => tracing::info!("Enabling GEMM reduced precision in BF16."), Err(e) => { if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { tracing::info!("GEMM reduced precision in BF16 not supported."); @@ -167,7 +167,7 @@ fn set_gemm_reduced_precision_f16() { let a = Tensor::zeros((2, 2), DType::F16, &Device::new_cuda(0).unwrap()).unwrap(); candle_core::cuda::set_gemm_reduced_precision_f16(true); match a.matmul(&a) { - Ok(_) => (), + Ok(_) => tracing::info!("Enabling GEMM reduced precision in F16."), Err(e) => { if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { tracing::info!("GEMM reduced precision in F16 not supported."); diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index d3795c413..37a9ce4eb 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -92,9 +92,11 @@ macro_rules! get_paths { info!("Using tokenizer.json at `{p}`"); PathBuf::from_str(p)? } else { + info!("Loading `tokenizer.json` at `{}`", $this.model_id); $crate::api_get_file!(api, "tokenizer.json", model_id) }; + info!("Loading `config.json` at `{}`", $this.model_id); let config_filename = $crate::api_get_file!(api, "config.json", model_id); let filenames = get_model_paths( @@ -125,6 +127,7 @@ macro_rules! get_paths { .collect::>() .contains(&"generation_config.json".to_string()) { + info!("Loading `generation_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, "generation_config.json", @@ -138,6 +141,7 @@ macro_rules! 
get_paths { .collect::>() .contains(&"preprocessor_config.json".to_string()) { + info!("Loading `preprocessor_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, "preprocessor_config.json", @@ -151,6 +155,7 @@ macro_rules! get_paths { .collect::>() .contains(&"processor_config.json".to_string()) { + info!("Loading `processor_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, "processor_config.json", @@ -160,6 +165,7 @@ macro_rules! get_paths { None }; + info!("Loading `tokenizer_config.json` at `{}`", $this.model_id); let template_filename = $crate::api_get_file!(api, "tokenizer_config.json", model_id); Ok(Box::new($path_name { @@ -188,14 +194,13 @@ macro_rules! get_paths_gguf { .with_token(get_token($token_source)?) .build()?; let revision = $revision.unwrap_or("main".to_string()); - let model_id_this = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone()); - let model_id_copy = model_id_this.clone(); + let this_model_id = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone()); let api = api.repo(Repo::with_revision( - model_id_this.clone(), + this_model_id.clone(), RepoType::Model, revision.clone(), )); - let model_id = std::path::Path::new(&model_id_copy); + let model_id = std::path::Path::new(&this_model_id); let chat_template = if let Some(ref p) = $this.chat_template { if p.ends_with(".json") { @@ -205,6 +210,7 @@ macro_rules! get_paths_gguf { PathBuf::from_str("")? } } else { + info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id); $crate::api_get_file!( api, "tokenizer_config.json", @@ -229,7 +235,7 @@ macro_rules! get_paths_gguf { xlora_config, lora_preload_adapter_info, } = get_xlora_paths( - model_id_this, + this_model_id.clone(), &$this.xlora_model_id, &$token_source, revision.clone(), @@ -240,6 +246,7 @@ macro_rules! get_paths_gguf { .collect::>() .contains(&"generation_config.json".to_string()) { + info!("Loading `generation_config.json` at `{}`", this_model_id); Some($crate::api_get_file!( api, "generation_config.json", @@ -253,6 +260,7 @@ macro_rules! get_paths_gguf { .collect::>() .contains(&"preprocessor_config.json".to_string()) { + info!("Loading `preprocessor_config.json` at `{}`", this_model_id); Some($crate::api_get_file!( api, "preprocessor_config.json", @@ -266,6 +274,7 @@ macro_rules! get_paths_gguf { .collect::>() .contains(&"processor_config.json".to_string()) { + info!("Loading `processor_config.json` at `{}`", this_model_id); Some($crate::api_get_file!( api, "processor_config.json", @@ -276,6 +285,7 @@ macro_rules! get_paths_gguf { }; let tokenizer_filename = if $this.model_id.is_some() { + info!("Loading `tokenizer.json` at `{}`", this_model_id); $crate::api_get_file!(api, "tokenizer.json", model_id) } else { PathBuf::from_str("")? 
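Note: the repeated pattern in the `get_paths`/`get_paths_gguf` hunks above — check the repo listing, log, then fetch — boils down to the rough sketch below. The function name and the simplified error handling are illustrative only (the real macros go through `api_get_file!` and propagate errors); it assumes the synchronous `hf_hub` API the crate already depends on.

```rust
use std::path::PathBuf;

use hf_hub::api::sync::ApiRepo;
use tracing::info;

/// Download `name` from the repo only when the repo listing contains it,
/// logging the access the same way the macros above now do.
fn get_optional_file(
    api: &ApiRepo,
    repo_files: &[String],
    name: &str,
    model_id: &str,
) -> Option<PathBuf> {
    if repo_files.iter().any(|f| f == name) {
        info!("Loading `{name}` at `{model_id}`");
        // Error handling is simplified here; the real macros surface download failures.
        api.get(name).ok()
    } else {
        None
    }
}
```

The point of the extra `info!` lines is simply that each optional file access is now visible in the log before the download happens.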
From 89dea1b1fc80213550eed53fb0020e188f092643 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 5 Jun 2024 05:46:26 -0400 Subject: [PATCH 6/9] Refactor enabling debug logging (#387) * Refactor enabling debug logging * Fix reversed order --- mistralrs-core/src/pipeline/ggml.rs | 3 +++ mistralrs-core/src/pipeline/gguf.rs | 21 ++++----------------- mistralrs-core/src/pipeline/normal.rs | 21 ++++----------------- mistralrs-core/src/utils/debug.rs | 21 +++++++++++++++++++++ mistralrs-core/src/utils/mod.rs | 1 + 5 files changed, 33 insertions(+), 34 deletions(-) create mode 100644 mistralrs-core/src/utils/debug.rs diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index a9f5d7843..3692b0072 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -16,6 +16,7 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::debug::setup_logger_and_debug; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; @@ -196,6 +197,8 @@ impl GGMLLoader { tokenizer_json: Option, tgt_non_granular_index: Option, ) -> Self { + setup_logger_and_debug(); + let model_id = if let Some(id) = model_id { id } else { diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 60ab71740..46f5f292c 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -17,6 +17,7 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::debug::setup_logger_and_debug; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; @@ -29,7 +30,7 @@ use crate::{ xlora_models::{XLoraQLlama, XLoraQPhi3}, GgufTokenizerConversion, }; -use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline, DEBUG}; +use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline}; use anyhow::{bail, Context, Result}; use candle_core::quantized::GgmlDType; use candle_core::{DType, Device, Tensor}; @@ -45,8 +46,6 @@ use strum::EnumString; use tokenizers::Tokenizer; use tokio::sync::Mutex; use tracing::info; -use tracing::level_filters::LevelFilter; -use tracing_subscriber::EnvFilter; enum Model { Llama(QLlama), @@ -231,6 +230,8 @@ impl GGUFLoader { chat_template: Option, tgt_non_granular_index: Option, ) -> Self { + setup_logger_and_debug(); + let model_id = if let Some(id) = model_id { Some(id) } else if let Some(xlora_order) = xlora_order.clone() { @@ -291,20 +292,6 @@ impl Loader for GGUFLoader { mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { - let is_debug = std::env::var("MISTRALRS_DEBUG") - .unwrap_or_default() - .contains('1'); - DEBUG.store(is_debug, std::sync::atomic::Ordering::Relaxed); - - let filter = EnvFilter::builder() - .with_default_directive(if is_debug { - LevelFilter::INFO.into() - } else { - LevelFilter::DEBUG.into() - }) - .from_env_lossy(); - tracing_subscriber::fmt().with_env_filter(filter).init(); - if in_situ_quant.is_some() { anyhow::bail!( "You are trying to in-situ quantize a GGUF model. This will not do anything." 
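For reference, a minimal self-contained sketch of the logging behaviour this patch centralises into `utils/debug.rs` (added below): `MISTRALRS_DEBUG=1` selects a `DEBUG` default directive instead of `INFO`, and because the filter is built with `from_env_lossy`, an explicit `RUST_LOG` value can still override that default. This is an illustration, not part of the patch.

```rust
use tracing::level_filters::LevelFilter;
use tracing_subscriber::EnvFilter;

fn main() {
    // Same logic as the new `setup_logger_and_debug`, minus the crate-global DEBUG flag.
    let is_debug = std::env::var("MISTRALRS_DEBUG")
        .unwrap_or_default()
        .contains('1');

    let filter = EnvFilter::builder()
        .with_default_directive(if is_debug {
            LevelFilter::DEBUG.into()
        } else {
            LevelFilter::INFO.into()
        })
        .from_env_lossy();
    tracing_subscriber::fmt().with_env_filter(filter).init();

    tracing::info!("visible by default");
    tracing::debug!("visible only with MISTRALRS_DEBUG=1 (or RUST_LOG=debug)");
}
```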
diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 7986ba60e..5ed6c022e 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -20,12 +20,13 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::debug::setup_logger_and_debug; use crate::utils::tokenizer::get_tokenizer; use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors}; use crate::xlora_models::NonGranularState; use crate::{ do_sample, get_mut_arcmutex, get_paths, lora_model_loader, normal_model_loader, - xlora_model_loader, DeviceMapMetadata, Pipeline, DEBUG, + xlora_model_loader, DeviceMapMetadata, Pipeline, }; use anyhow::Result; use candle_core::quantized::GgmlDType; @@ -40,8 +41,6 @@ use std::sync::Arc; use tokenizers::Tokenizer; use tokio::sync::Mutex; use tracing::info; -use tracing::level_filters::LevelFilter; -use tracing_subscriber::EnvFilter; pub struct NormalPipeline { model: Box, @@ -155,6 +154,8 @@ impl NormalLoaderBuilder { } pub fn build(self, loader: NormalLoaderType) -> Box { + setup_logger_and_debug(); + let loader: Box = match loader { NormalLoaderType::Mistral => Box::new(MistralLoader), NormalLoaderType::Gemma => Box::new(GemmaLoader), @@ -213,20 +214,6 @@ impl Loader for NormalLoader { mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { - let is_debug = std::env::var("MISTRALRS_DEBUG") - .unwrap_or_default() - .contains('1'); - DEBUG.store(is_debug, std::sync::atomic::Ordering::Relaxed); - - let filter = EnvFilter::builder() - .with_default_directive(if is_debug { - LevelFilter::INFO.into() - } else { - LevelFilter::DEBUG.into() - }) - .from_env_lossy(); - tracing_subscriber::fmt().with_env_filter(filter).init(); - let config = std::fs::read_to_string(paths.get_config_filename())?; let default_dtype = if device.is_cuda() && mapper.is_dummy() { DType::BF16 diff --git a/mistralrs-core/src/utils/debug.rs b/mistralrs-core/src/utils/debug.rs new file mode 100644 index 000000000..ebb71ff01 --- /dev/null +++ b/mistralrs-core/src/utils/debug.rs @@ -0,0 +1,21 @@ +use tracing::level_filters::LevelFilter; +use tracing_subscriber::EnvFilter; + +use crate::DEBUG; + +// This should be called in each `Loader` when it is created. 
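+// Setting MISTRALRS_DEBUG=1 switches the default log level from INFO to DEBUG and
+// sets the crate-wide DEBUG flag (used, e.g., to dump GGUF tensor info); an explicit
+// RUST_LOG value read via `from_env_lossy` still overrides the default directive.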
+pub(crate) fn setup_logger_and_debug() { + let is_debug = std::env::var("MISTRALRS_DEBUG") + .unwrap_or_default() + .contains('1'); + DEBUG.store(is_debug, std::sync::atomic::Ordering::Relaxed); + + let filter = EnvFilter::builder() + .with_default_directive(if is_debug { + LevelFilter::DEBUG.into() + } else { + LevelFilter::INFO.into() + }) + .from_env_lossy(); + tracing_subscriber::fmt().with_env_filter(filter).init(); +} diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index e5911b836..8b6bb9ae0 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod debug; pub(crate) mod model_config; pub(crate) mod progress; pub(crate) mod tokenizer; From 5f5c490fcb30c4198b422fcbee78c3fb5d75732c Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 5 Jun 2024 07:42:59 -0400 Subject: [PATCH 7/9] Support loading chat template from gguf --- mistralrs-core/src/gguf/chat_template.rs | 16 +++ mistralrs-core/src/gguf/mod.rs | 2 + mistralrs-core/src/pipeline/chat_template.rs | 2 +- mistralrs-core/src/pipeline/ggml.rs | 2 +- mistralrs-core/src/pipeline/gguf.rs | 12 +- mistralrs-core/src/pipeline/macros.rs | 34 ++++-- mistralrs-core/src/pipeline/mod.rs | 10 +- mistralrs-core/src/pipeline/normal.rs | 2 +- mistralrs-core/src/pipeline/paths.rs | 110 +++++++++++-------- 9 files changed, 124 insertions(+), 66 deletions(-) create mode 100644 mistralrs-core/src/gguf/chat_template.rs diff --git a/mistralrs-core/src/gguf/chat_template.rs b/mistralrs-core/src/gguf/chat_template.rs new file mode 100644 index 000000000..9dda8a84d --- /dev/null +++ b/mistralrs-core/src/gguf/chat_template.rs @@ -0,0 +1,16 @@ +use super::Content; + +// Get chat template from GGUF metadata if it exists +pub fn get_gguf_chat_template( + content: &Content<'_, R>, +) -> Option { + content + .get_metadata("tokenizer.chat_template") + .ok() + .map(|template| { + template + .to_string() + .expect("Chat template must be a string") + .clone() + }) +} diff --git a/mistralrs-core/src/gguf/mod.rs b/mistralrs-core/src/gguf/mod.rs index c66133ea8..cb6704583 100644 --- a/mistralrs-core/src/gguf/mod.rs +++ b/mistralrs-core/src/gguf/mod.rs @@ -1,5 +1,7 @@ +mod chat_template; mod content; mod gguf_tokenizer; +pub use chat_template::get_gguf_chat_template; pub use content::Content; pub use gguf_tokenizer::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion}; diff --git a/mistralrs-core/src/pipeline/chat_template.rs b/mistralrs-core/src/pipeline/chat_template.rs index e9d3d892a..1097ea329 100644 --- a/mistralrs-core/src/pipeline/chat_template.rs +++ b/mistralrs-core/src/pipeline/chat_template.rs @@ -37,7 +37,7 @@ pub struct BeginEndUnkTok( ); #[allow(dead_code)] -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Default)] pub struct ChatTemplate { add_bos_token: Option, add_eos_token: Option, diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 3692b0072..6e7021c4d 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -317,7 +317,7 @@ impl Loader for GGMLLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let chat_template = get_chat_template(paths, &self.chat_template); + let chat_template = get_chat_template(paths, &self.chat_template, None); let max_seq_len = match model { Model::Llama(ref l) => l.max_seq_len, diff --git a/mistralrs-core/src/pipeline/gguf.rs 
b/mistralrs-core/src/pipeline/gguf.rs index 46f5f292c..1f87b09ec 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -10,11 +10,11 @@ use super::{ }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; -use crate::gguf::Content; +use crate::gguf::{get_gguf_chat_template, Content}; use crate::lora::Ordering; use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig}; +use crate::pipeline::ChatTemplate; use crate::pipeline::{get_chat_template, Cache}; -use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; use crate::utils::debug::setup_logger_and_debug; @@ -30,7 +30,9 @@ use crate::{ xlora_models::{XLoraQLlama, XLoraQPhi3}, GgufTokenizerConversion, }; -use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline}; +use crate::{ + do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths, Pipeline, +}; use anyhow::{bail, Context, Result}; use candle_core::quantized::GgmlDType; use candle_core::{DType, Device, Tensor}; @@ -328,6 +330,8 @@ impl Loader for GGUFLoader { } }; + let gguf_chat_template = get_gguf_chat_template(&content); + let has_adapter = self.kind.is_adapted(); let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); @@ -371,7 +375,7 @@ impl Loader for GGUFLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let mut chat_template = get_chat_template(paths, &self.chat_template); + let mut chat_template = get_chat_template(paths, &self.chat_template, gguf_chat_template); let max_seq_len = match model { Model::Llama(ref l) => l.max_seq_len, diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 37a9ce4eb..66bccdac4 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -165,8 +165,17 @@ macro_rules! get_paths { None }; - info!("Loading `tokenizer_config.json` at `{}`", $this.model_id); - let template_filename = $crate::api_get_file!(api, "tokenizer_config.json", model_id); + let template_filename = if let Some(ref p) = $this.chat_template { + info!("Using chat template file at `{p}`"); + Some(PathBuf::from_str(p)?) + } else { + info!("Loading `tokenizer_config.json` at `{}`", $this.model_id); + Some($crate::api_get_file!( + api, + "tokenizer_config.json", + model_id + )) + }; Ok(Box::new($path_name { tokenizer_filename, @@ -205,17 +214,22 @@ macro_rules! get_paths_gguf { let chat_template = if let Some(ref p) = $this.chat_template { if p.ends_with(".json") { info!("Using chat template file at `{p}`"); - PathBuf::from_str(p)? + Some(PathBuf::from_str(p)?) } else { - PathBuf::from_str("")? 
+ panic!("Specified chat template file must end with .json"); } } else { - info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id); - $crate::api_get_file!( - api, - "tokenizer_config.json", - model_id - ) // Will be loaded from inside gguf file + if $this.model_id.is_none() { + None + } else { + info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id); + let res = $crate::api_get_file!( + api, + "tokenizer_config.json", + model_id + ); + Some(res) + } }; let filenames = get_model_paths( diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 262ec0663..21769f6b0 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -63,8 +63,8 @@ pub trait ModelPaths { /// See: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer fn get_tokenizer_filename(&self) -> &PathBuf; - /// Content expected to deserialize to [`ChatTemplate`]. - fn get_template_filename(&self) -> &PathBuf; + /// File where the content is expected to deserialize to [`ChatTemplate`]. + fn get_template_filename(&self) -> &Option; /// Optional adapter files. `(String, PathBuf)` is of the form `(id name, path)`. fn get_adapter_filenames(&self) -> &Option>; @@ -98,7 +98,7 @@ pub trait ModelPaths { pub struct LocalModelPaths
<P>
{ tokenizer_filename: P, config_filename: P, - template_filename: P, + template_filename: Option<P>
, filenames: Vec<P>
, xlora_adapter_filenames: Option>, xlora_adapter_configs: Option>, @@ -131,7 +131,7 @@ impl
<P>
LocalModelPaths<P>
{ Self { tokenizer_filename, config_filename, - template_filename, + template_filename: Some(template_filename), filenames, xlora_adapter_filenames, xlora_adapter_configs, @@ -171,7 +171,7 @@ impl ModelPaths for LocalModelPaths { fn get_ordering(&self) -> &Option { &self.xlora_ordering } - fn get_template_filename(&self) -> &PathBuf { + fn get_template_filename(&self) -> &Option { &self.template_filename } fn get_gen_conf_filename(&self) -> Option<&PathBuf> { diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 5ed6c022e..6a71da6b4 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -292,7 +292,7 @@ impl Loader for NormalLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let chat_template = get_chat_template(paths, &self.chat_template); + let chat_template = get_chat_template(paths, &self.chat_template, None); if let Some(in_situ_quant) = in_situ_quant { model.quantize(in_situ_quant, device.clone())?; diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index 312e31ea7..4aefc372d 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -287,38 +287,60 @@ pub fn get_model_paths( /// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`]. /// If the the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not /// have a `chat_template`, use the provided one. +/// +/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file. +/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially. +/// Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this +/// is used.* #[allow(clippy::borrowed_box)] pub(crate) fn get_chat_template( paths: &Box, - chat_template: &Option, + chat_template_fallback: &Option, + chat_template_ovrd: Option, ) -> ChatTemplate { - let template_filename = if paths.get_template_filename().to_string_lossy().is_empty() { - PathBuf::from( - chat_template - .as_ref() - .expect("A tokenizer config or chat template file path must be specified."), - ) + // Get template content, this may be overridden. + let template_content = if let Some(template_filename) = paths.get_template_filename() { + if template_filename + .extension() + .expect("Template filename must be a file") + .to_string_lossy() + != "json" + { + panic!("Template filename {template_filename:?} must end with `.json`."); + } + fs::read_to_string(&template_filename).expect("Loading chat template failed.") } else { - paths.get_template_filename().clone() + if chat_template_fallback + .as_ref() + .is_some_and(|f| f.ends_with(".json")) + { + // User specified a file + let template_filename = chat_template_fallback + .as_ref() + .expect("A tokenizer config or chat template file path must be specified."); + fs::read_to_string(&template_filename).expect("Loading chat template failed.") + } else { + panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there."); + } + }; + + let template: ChatTemplate = match chat_template_ovrd { + Some(chat_template) => { + // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves. 
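+ // Only the template string itself comes from the GGUF metadata here; every other
+ // `ChatTemplate` field keeps its `Default` value (hence the note about bos/eos/unk).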
+ info!("Using literal chat template."); + let mut template = ChatTemplate::default(); + template.chat_template = Some(chat_template); + template + } + None => serde_json::from_str(&template_content).unwrap(), }; - if template_filename - .extension() - .expect("Template filename must be a file") - .to_string_lossy() - != "json" - { - panic!("Template filename {template_filename:?} must end with `.json`."); - } - let template: ChatTemplate = serde_json::from_str( - &fs::read_to_string(&template_filename).expect("Deserialization of chat template failed."), - ) - .unwrap(); #[derive(Debug, serde::Deserialize)] struct SpecifiedTemplate { chat_template: String, bos_token: Option, eos_token: Option, + unk_token: Option, } match &template.chat_template { @@ -326,34 +348,34 @@ pub(crate) fn get_chat_template( None => { info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template."); let mut deser: HashMap = - serde_json::from_str(&fs::read_to_string(&template_filename).unwrap()).unwrap(); + serde_json::from_str(&template_content).unwrap(); - match chat_template.clone() { + match chat_template_fallback.clone() { Some(t) => { - if t.ends_with(".json") { - info!("Loading specified loading chat template file at `{t}`."); - let templ: SpecifiedTemplate = - serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap(); + info!("Loading specified loading chat template file at `{t}`."); + let templ: SpecifiedTemplate = + serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap(); + deser.insert( + "chat_template".to_string(), + Value::String(templ.chat_template), + ); + if templ.bos_token.is_some() { deser.insert( - "chat_template".to_string(), - Value::String(templ.chat_template), + "bos_token".to_string(), + Value::String(templ.bos_token.unwrap()), + ); + } + if templ.eos_token.is_some() { + deser.insert( + "eos_token".to_string(), + Value::String(templ.eos_token.unwrap()), + ); + } + if templ.unk_token.is_some() { + deser.insert( + "unk_token".to_string(), + Value::String(templ.unk_token.unwrap()), ); - if templ.bos_token.is_some() { - deser.insert( - "bos_token".to_string(), - Value::String(templ.bos_token.unwrap()), - ); - } - if templ.eos_token.is_some() { - deser.insert( - "eos_token".to_string(), - Value::String(templ.eos_token.unwrap()), - ); - } - info!("Loaded chat template file."); - } else { - deser.insert("chat_template".to_string(), Value::String(t)); - info!("Loaded specified literal chat template."); } } None => { From 036683b1687b2365b9beee122fa909ceaf5e7c56 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 5 Jun 2024 07:46:28 -0400 Subject: [PATCH 8/9] Some checks and logging --- mistralrs-core/src/gguf/chat_template.rs | 10 +++++++--- mistralrs-core/src/pipeline/gguf.rs | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/gguf/chat_template.rs b/mistralrs-core/src/gguf/chat_template.rs index 9dda8a84d..35aff5c51 100644 --- a/mistralrs-core/src/gguf/chat_template.rs +++ b/mistralrs-core/src/gguf/chat_template.rs @@ -1,6 +1,8 @@ +use tracing::info; + use super::Content; -// Get chat template from GGUF metadata if it exists +// Get chat template from GGUF metadata if it exists. 
pub fn get_gguf_chat_template( content: &Content<'_, R>, ) -> Option { @@ -8,9 +10,11 @@ pub fn get_gguf_chat_template( .get_metadata("tokenizer.chat_template") .ok() .map(|template| { - template + let template = template .to_string() .expect("Chat template must be a string") - .clone() + .clone(); + info!("Discovered and using GGUF chat template: `{template}`"); + template }) } diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 1f87b09ec..0af9fa74c 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -330,7 +330,12 @@ impl Loader for GGUFLoader { } }; - let gguf_chat_template = get_gguf_chat_template(&content); + // Only load gguf chat template if there is nothing else + let gguf_chat_template = if paths.get_template_filename().is_none() { + get_gguf_chat_template(&content) + } else { + None + }; let has_adapter = self.kind.is_adapted(); let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); From ff7e14b0aa70c18bf4d7f99c25eaacfadb15ce52 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Wed, 5 Jun 2024 08:26:30 -0400 Subject: [PATCH 9/9] Clippy --- mistralrs-core/src/pipeline/paths.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index 4aefc372d..cc792fec9 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -308,20 +308,18 @@ pub(crate) fn get_chat_template( { panic!("Template filename {template_filename:?} must end with `.json`."); } - fs::read_to_string(&template_filename).expect("Loading chat template failed.") - } else { - if chat_template_fallback + fs::read_to_string(template_filename).expect("Loading chat template failed.") + } else if chat_template_fallback + .as_ref() + .is_some_and(|f| f.ends_with(".json")) + { + // User specified a file + let template_filename = chat_template_fallback .as_ref() - .is_some_and(|f| f.ends_with(".json")) - { - // User specified a file - let template_filename = chat_template_fallback - .as_ref() - .expect("A tokenizer config or chat template file path must be specified."); - fs::read_to_string(&template_filename).expect("Loading chat template failed.") - } else { - panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there."); - } + .expect("A tokenizer config or chat template file path must be specified."); + fs::read_to_string(template_filename).expect("Loading chat template failed.") + } else { + panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there."); }; let template: ChatTemplate = match chat_template_ovrd {