diff --git a/Cargo.toml b/Cargo.toml index e8482cf1ff..f5d322a42c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } futures = "0.3" clap = { version = "4.5.1", features = ["derive"] } -pyo3 = { version = "0.21.0", features = ["full", "extension-module"] } +pyo3 = { version = "0.21.0", features = ["full", "extension-module", "either"] } tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] } once_cell = "1.19.0" image = "0.25.1" diff --git a/README.md b/README.md index 809c96b776..1b41bd5125 100644 --- a/README.md +++ b/README.md @@ -284,6 +284,14 @@ please consider using the method demonstrated in examples below, where the token **Supported GGUF tokenizer types** - `llama` +Some GGUF models are very large and are sharded into multiple files. Mistral.rs supports this, and to use it, delimit the `.gguf` filenames with a space as such: + +```bash +./mistralrs-server --chat-template gguf -m . -f "a.gguf b.gguf" +``` + +For the Python API, a list of strings is also accepted for this case. + ## Run To start a server serving Mistral GGUF on `localhost:1234`, diff --git a/mistralrs-core/src/device_map.rs b/mistralrs-core/src/device_map.rs index f0aacfdf67..23062ca7e1 100644 --- a/mistralrs-core/src/device_map.rs +++ b/mistralrs-core/src/device_map.rs @@ -18,6 +18,7 @@ impl DeviceMapMetadata { host_layers: None, } } + /// A device mapper to not map device. pub fn dummy() -> Self { Self { device_layers: None, diff --git a/mistralrs-core/src/gguf/chat_template.rs b/mistralrs-core/src/gguf/chat_template.rs new file mode 100644 index 0000000000..35aff5c516 --- /dev/null +++ b/mistralrs-core/src/gguf/chat_template.rs @@ -0,0 +1,20 @@ +use tracing::info; + +use super::Content; + +// Get chat template from GGUF metadata if it exists. +pub fn get_gguf_chat_template( + content: &Content<'_, R>, +) -> Option { + content + .get_metadata("tokenizer.chat_template") + .ok() + .map(|template| { + let template = template + .to_string() + .expect("Chat template must be a string") + .clone(); + info!("Discovered and using GGUF chat template: `{template}`"); + template + }) +} diff --git a/mistralrs-core/src/gguf/content.rs b/mistralrs-core/src/gguf/content.rs new file mode 100644 index 0000000000..52e9724346 --- /dev/null +++ b/mistralrs-core/src/gguf/content.rs @@ -0,0 +1,167 @@ +use std::fs; + +use anyhow::Context; +use candle_core::{ + quantized::{ + gguf_file::{self, Value}, + QTensor, + }, + Device, Result, +}; +use indexmap::IndexMap; +use tracing::info; + +use crate::{pipeline::GGUFArchitecture, DEBUG}; + +fn parse_gguf_value(value: &Value) -> String { + match value { + Value::Array(vs) => vs + .iter() + .map(parse_gguf_value) + .collect::>() + .join(", "), + Value::Bool(b) => b.to_string(), + Value::F32(x) => x.to_string(), + Value::F64(x) => x.to_string(), + Value::I8(x) => x.to_string(), + Value::I16(x) => x.to_string(), + Value::I32(x) => x.to_string(), + Value::I64(x) => x.to_string(), + Value::String(x) => x.to_string(), + Value::U8(x) => x.to_string(), + Value::U16(x) => x.to_string(), + Value::U32(x) => x.to_string(), + Value::U64(x) => x.to_string(), + } +} + +// Internal invariant: contents and readers must be paired. +/// This abstracts the files for a GGUF model and enables multiple files to be used. 
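As a rough usage sketch of the sharded-GGUF support described in the README hunk and of the `Content` abstraction documented above (not part of the diff; the `mistralrs_core::gguf::Content` path, the shard names, and the CPU device are assumptions), the following mirrors how the pipeline code later in this patch drives `from_readers`, `arch`, `get_metadata`, and `tensor`:

```rust
use candle_core::Device;
use mistralrs_core::gguf::Content; // assumed public path for the new module

fn open_sharded_gguf() -> anyhow::Result<()> {
    // Hypothetical shard names; on the CLI these are passed space-delimited.
    let mut files = Vec::new();
    for shard in ["a.gguf", "b.gguf"] {
        files.push(std::fs::File::open(shard)?);
    }
    // Pair each reader with its parsed GGUF header. `from_readers` also
    // verifies that the reader count matches `split.count` when present.
    let mut readers = files.iter_mut().collect::<Vec<_>>();
    let mut content = Content::from_readers(&mut readers)?;

    println!("architecture = {:?}", content.arch());
    // Metadata lookups and tensor loads search every shard.
    let blocks = content.get_metadata("llama.block_count")?.to_u32()?;
    let tok_embd = content.tensor("token_embd.weight", &Device::Cpu)?;
    println!("{blocks} blocks, token_embd shape = {:?}", tok_embd.shape());
    Ok(())
}
```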
+pub struct Content<'a, R: std::io::Seek + std::io::Read> { + contents: Vec, + readers: &'a mut [&'a mut R], + arch: GGUFArchitecture, +} + +impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> { + /// Create a `Content` from a set of file readers. + pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result { + let mut contents = Vec::new(); + let n_readers = readers.len(); + for reader in readers.iter_mut() { + contents.push(gguf_file::Content::read(reader)?); + } + let n_splits = contents + .iter() + .filter_map(|ct| { + ct.metadata + .get("split.count") + .map(|val| val.to_u64().unwrap()) + }) + .collect::>(); + if n_splits.len() > 1 { + candle_core::bail!("Multiple contents have multiple `split.count` fields"); + } + #[allow(clippy::cast_possible_truncation)] + if !n_splits.is_empty() && n_readers != n_splits[0] as usize { + candle_core::bail!("Number of readers does not match the number of splits."); + } else if n_splits.len() == 1 { + info!("Model n splits: {}", n_splits[0]); + } + + let mut arch = None; + for ct in &contents { + if !ct.metadata.contains_key("general.architecture") { + continue; + } + + arch = Some( + ct.metadata["general.architecture"] + .to_string() + .context("Model metadata should have declared an architecture") + .and_then(GGUFArchitecture::from_value) + .unwrap(), + ); + } + let arch = arch.expect("GGUF files must specify `general.architecture`"); + Ok(Self { + contents, + readers, + arch, + }) + } + + pub fn arch(&self) -> GGUFArchitecture { + self.arch + } + + /// Retrieve a tensor, searching through each content. + pub fn tensor(&mut self, name: &str, device: &Device) -> Result { + for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) { + if let Some(tensor_info) = ct.tensor_infos.get(name) { + return tensor_info.read(reader, ct.tensor_data_offset, device); + } + } + candle_core::bail!("Cannot find tensor info for {name}") + } + + /// Print metadata for these contents. + /// This will also log tensor name, shape and dtype to `mistralrs_gguf_tensors.txt` is DEBUG is enabled. 
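The first consumers of `Content` added in this patch are the GGUF-to-HF tokenizer converter and the chat-template probe, defined in the neighboring `gguf_tokenizer.rs` and `gguf/chat_template.rs` files. A minimal sketch of wiring them together, assuming both are reachable through the public `gguf` module (the crate-root re-exports of `convert_gguf_to_hf_tokenizer` and `GgufTokenizerConversion` appear in the `lib.rs` hunk further down):

```rust
use mistralrs_core::gguf::{get_gguf_chat_template, Content};
use mistralrs_core::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion};

fn tokenizer_and_template<R: std::io::Seek + std::io::Read>(
    content: &Content<'_, R>,
) -> anyhow::Result<(GgufTokenizerConversion, Option<String>)> {
    // Rebuilds a `tokenizers::Tokenizer` plus optional bos/eos/unk strings
    // from the `tokenizer.ggml.*` metadata keys.
    let conversion = convert_gguf_to_hf_tokenizer(content)?;
    // Returns the embedded Jinja template only if `tokenizer.chat_template`
    // is present in the metadata.
    let chat_template = get_gguf_chat_template(content);
    Ok((conversion, chat_template))
}
```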
+ pub fn print_metadata(&self) -> anyhow::Result<()> { + // Find the ct with general.architecture + let mut keys = Vec::new(); + let mut metadatas = Vec::new(); + let mut tensors = Vec::new(); + for ct in &self.contents { + keys.extend(ct.metadata.keys()); + metadatas.push(&ct.metadata); + + if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { + for (name, info) in &ct.tensor_infos { + tensors.push(format!( + "name = `{name}`, shape = {:?}, dtype = {:?}", + info.shape.clone(), + info.ggml_dtype + )); + } + } + } + + info!("Model config:"); + keys.sort(); + let mut output_keys = IndexMap::new(); + for name in keys { + if !name.contains("tokenizer") { + for metadata in &metadatas { + if let Some(val) = metadata.get(name) { + output_keys.insert(name, parse_gguf_value(val)); + } + } + } + } + for (name, val) in output_keys { + println!("{name}: {val}") + } + + if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { + fs::write( + "mistralrs_gguf_tensors.txt", + serde_json::to_string_pretty(&tensors).expect("Serialization failed."), + )?; + + info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); + } + + anyhow::Ok(()) + } + + /// Get metadata + pub fn get_metadata(&self, name: &str) -> Result<&Value> { + for content in &self.contents { + if let Some(v) = content.metadata.get(name) { + return Ok(v); + } + } + candle_core::bail!("Cannot find metadata for {name}") + } +} diff --git a/mistralrs-core/src/pipeline/gguf_tokenizer.rs b/mistralrs-core/src/gguf/gguf_tokenizer.rs similarity index 90% rename from mistralrs-core/src/pipeline/gguf_tokenizer.rs rename to mistralrs-core/src/gguf/gguf_tokenizer.rs index 4d0ce613ce..363aaffb0c 100644 --- a/mistralrs-core/src/pipeline/gguf_tokenizer.rs +++ b/mistralrs-core/src/gguf/gguf_tokenizer.rs @@ -1,7 +1,6 @@ use std::sync::atomic::Ordering; use anyhow::Result; -use candle_core::quantized::gguf_file::Content; use tokenizers::{ decoders::{self, byte_fallback::ByteFallback, fuse::Fuse, strip::Strip}, models::unigram::Unigram, @@ -12,27 +11,33 @@ use tracing::info; use crate::DEBUG; -pub struct ConversionResult { +use super::Content; + +pub struct GgufTokenizerConversion { pub tokenizer: Tokenizer, pub bos: Option, pub eos: Option, pub unk: Option, } -pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { - let model = content.metadata["tokenizer.ggml.model"] +// Convert GGUF tokenizer to tokenizer and metadata +pub fn convert_gguf_to_hf_tokenizer( + content: &Content<'_, R>, +) -> Result { + let model = content + .get_metadata("tokenizer.ggml.model")? .to_string() .expect("GGUF tokenizer model is not a string.") .clone(); - let tokens = content.metadata["tokenizer.ggml.tokens"] + let tokens = content + .get_metadata("tokenizer.ggml.tokens")? 
.to_vec() .expect("GGUF tokenizer tokens is not a vec.") .iter() .map(|t| t.to_string().expect("GGUF token is not a string.").clone()) .collect::>(); let added_tokens = content - .metadata - .get("tokenizer.ggml.added_tokens") + .get_metadata("tokenizer.ggml.added_tokens") .map(|items| { items .to_vec() @@ -45,7 +50,7 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result>() }); - let scores = content.metadata.get("tokenizer.ggml.scores").map(|items| { + let scores = content.get_metadata("tokenizer.ggml.scores").map(|items| { items .to_vec() .expect("GGUF tokenizer scores is not a vec.") @@ -53,7 +58,7 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result>() }); - let merges = content.metadata.get("tokenizer.ggml.merges").map(|items| { + let merges = content.get_metadata("tokenizer.ggml.merges").map(|items| { items .to_vec() .expect("GGUF tokenizer merges is not a vec.") @@ -63,15 +68,16 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result Result Result( - ct: &gguf_file::Content, - r: &mut R, + ct: &mut Content<'_, R>, name: &str, device: &Device, ) -> Result { - let w = ct.tensor(r, &format!("{name}.weight"), device)?; - let b = ct.tensor(r, &format!("{name}.bias"), device)?; + let w = ct.tensor(&format!("{name}.weight"), device)?; + let b = ct.tensor(&format!("{name}.bias"), device)?; let inner = QMatMul::from_qtensor(w)?; let bias = b.dequantize(device)?; Ok(Self { diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 3876cfdc0d..0be8760b5c 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -26,6 +26,7 @@ mod model_selected; pub use model_selected::ModelSelected; mod cublaslt; +pub mod gguf; pub mod layers; mod layers_masker; mod layers_utils; @@ -42,11 +43,12 @@ mod utils; mod xlora_models; pub use device_map::{DeviceMapMetadata, LayerDeviceMapper}; +pub use gguf::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion}; pub use pipeline::{ - GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, - GGUFSpecificConfig, GemmaLoader, LlamaLoader, Loader, LocalModelPaths, MistralLoader, - MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder, NormalLoaderType, - NormalSpecificConfig, Phi2Loader, Phi3Loader, Qwen2Loader, SpeculativeConfig, + GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFArchitecture, GGUFLoader, + GGUFLoaderBuilder, GGUFSpecificConfig, GemmaLoader, LlamaLoader, Loader, LocalModelPaths, + MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder, + NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, TokenSource, }; pub use request::{Constraint, Content, NormalRequest, Request, RequestMessage}; @@ -60,6 +62,8 @@ pub use toml_selector::{TomlLoaderArgs, TomlSelector}; /// `true` if `MISTRALRS_DEBUG=1` pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false); +/// Delimiter for GGUF multiple files in the CLI. +pub const GGUF_MULTI_FILE_DELIMITER: &str = " "; /// The MistralRs struct handles sending requests to the engine. 
/// It is the core multi-threaded component of mistral.rs, and uses `mspc` diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index 3d61eb62cc..c36d198045 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -6,6 +6,7 @@ use crate::{ NormalSpecificConfig, }, Loader, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, + GGUF_MULTI_FILE_DELIMITER, }; pub struct LoaderBuilder { @@ -158,7 +159,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result>(), ) .build(), ModelSelected::XLoraGGUF { @@ -174,7 +178,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result>(), ) .with_xlora( xlora_model_id, @@ -198,7 +205,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result>(), ) .with_lora( adapters_model_id, diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index a5771038c5..15bb3a58ad 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -1,16 +1,16 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +use candle_core::quantized::ggml_file; use candle_core::quantized::QMatMul; -use candle_core::quantized::{ggml_file, gguf_file}; use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, RotaryEmbedding}; use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention, }; use crate::pipeline::{extract_logits, Cache}; -use crate::utils::max_seq_len::get_gguf_max_seq_len; use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; @@ -260,49 +260,50 @@ impl ModelConfig::FromGGML for ModelWeights { impl ModelConfig::FromGGUF for ModelWeights { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, mapper: DeviceMapMetadata, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; verify_sanity_gguf( - md_get("general.architecture")?.to_string().unwrap(), + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), "llama", )?; // Parameter extraction from metadata. - let n_expert = md_get("llama.expert_count") + let n_expert = ct + .get_metadata("llama.expert_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let n_expert_used = md_get("llama.expert_used_count") + let n_expert_used = ct + .get_metadata("llama.expert_used_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + let head_count = ct.get_metadata("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("llama.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("llama.rope.dimension_count")?.to_u32()? 
as usize; // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. - let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + let rms_norm_eps = ct + .get_metadata("llama.attention.layer_norm_rms_epsilon")? + .to_f32()?; - let max_seq_len = - get_gguf_max_seq_len(md_get("llama.context_length"), MAX_SEQ_LEN as u64) as usize; + let max_seq_len = ct + .get_metadata("llama.context_length")? + .to_u64() + .unwrap_or(MAX_SEQ_LEN as u64) as usize; - let rope_freq_base = md_get("llama.rope.freq_base") + let rope_freq_base = ct + .get_metadata("llama.rope.freq_base") .and_then(|m| m.to_f32()) .unwrap_or(10000f32); let head_dim = embedding_length / head_count; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let norm = QRmsNorm::new( - ct.tensor(reader, "output_norm.weight", device)?, - rms_norm_eps, - )?; - let output = ct.tensor(reader, "output.weight", device)?; + let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?; + let output = ct.tensor("output.weight", device)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; for layer_idx in 0..block_count { @@ -318,18 +319,14 @@ impl ModelConfig::FromGGUF for ModelWeights { DType::F32, )?; - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; - let attention_wo = - ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?; let mlp_or_moe = if n_expert <= 1 { - let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; - let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; - let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?; MlpOrMoe::Mlp(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -337,15 +334,15 @@ impl ModelConfig::FromGGUF for ModelWeights { }) } else { let feed_forward_gate_inp = - ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?; let mut experts = Vec::with_capacity(n_expert); for i in 0..n_expert { let feed_forward_w1 = - ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?; let feed_forward_w2 = - ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?; let feed_forward_w3 = - ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; + ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), 
device)?; experts.push(Mlp { feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, @@ -358,9 +355,8 @@ impl ModelConfig::FromGGUF for ModelWeights { experts, } }; - let attention_norm = - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; + let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?; layers.push(LayerWeights { attention_wq: QMatMul::from_qtensor(attention_wq)?, attention_wk: QMatMul::from_qtensor(attention_wk)?, diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index 7a752f7873..83a61a9cc6 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -1,15 +1,14 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::quantized::gguf_file; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Embedding, LayerNorm}; use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::ScaledDotProductAttention; use crate::layers::{repeat_kv, CausalMasker, QLinear}; use crate::pipeline::{extract_logits, Cache}; -use crate::utils::max_seq_len::get_gguf_max_seq_len; use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; @@ -145,53 +144,52 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result { impl ModelConfig::FromGGUF for ModelWeights { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, mapper: DeviceMapMetadata, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - // Parameter extraction from metadata. - let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("phi2.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize; - let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64; - let max_seq_len = - get_gguf_max_seq_len(md_get("phi2.context_length"), MAX_SEQ_LEN as u64) as usize; + let head_count = ct.get_metadata("phi2.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("phi2.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("phi2.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("phi2.embedding_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("phi2.rope.dimension_count")?.to_u32()? as usize; + let ln_eps = ct + .get_metadata("phi2.attention.layer_norm_epsilon")? + .to_f32()? as f64; + + let max_seq_len = ct + .get_metadata("phi2.context_length")? 
+ .to_u64() + .unwrap_or(MAX_SEQ_LEN as u64) as usize; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; let output_norm = layer_norm( - ct.tensor(reader, "output_norm.weight", device)?, - ct.tensor(reader, "output_norm.bias", device)?, + ct.tensor("output_norm.weight", device)?, + ct.tensor("output_norm.bias", device)?, ln_eps, )?; - let output = QLinear::new(&ct, reader, "output", device)?; + let output = QLinear::new(&mut ct, "output", device)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let device = mapper.device_for(layer_idx, false).unwrap_or(device); - let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; - let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; + let ffn_up = QLinear::new(&mut ct, &format!("{prefix}.ffn_up"), device)?; + let ffn_down = QLinear::new(&mut ct, &format!("{prefix}.ffn_down"), device)?; let mlp = Mlp { ffn_up, ffn_down }; let attn_norm = layer_norm( - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, - ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.bias"), device)?, ln_eps, )?; layers.push(LayerWeights { - attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, - attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, + attn_qkv: QLinear::new(&mut ct, &format!("{prefix}.attn_qkv"), device)?, + attn_output: QLinear::new(&mut ct, &format!("{prefix}.attn_output"), device)?, attn_norm, mlp, n_head: head_count, diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index 8149eea845..b461e5abdd 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -1,13 +1,13 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] use crate::device_map::DeviceMapper; +use crate::gguf::Content; use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention, }; use crate::pipeline::Cache; use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; -use candle_core::quantized::gguf_file; use candle_core::quantized::QMatMul; use candle_core::quantized::QTensor; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; @@ -162,71 +162,63 @@ fn precomput_freqs_cis( impl ModelConfig::FromGGUF for ModelWeights { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, mapper: DeviceMapMetadata, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?; + verify_sanity_gguf( + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), + "phi3", + )?; // Parameter extraction from metadata. - let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? 
as usize; - let block_count = md_get("phi3.block_count")?.to_u32()? as usize; - let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize; - let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize; - let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize; - let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; - let context_window = md_get("phi3.context_length")?.to_u32()? as usize; + let head_count = ct.get_metadata("phi3.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("phi3.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("phi3.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("phi3.embedding_length")?.to_u32()? as usize; + let i_size = ct.get_metadata("phi3.feed_forward_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("phi3.rope.dimension_count")?.to_u32()? as usize; + let rms_eps = ct + .get_metadata("phi3.attention.layer_norm_rms_epsilon")? + .to_f32()? as f64; + let context_window = ct.get_metadata("phi3.context_length")?.to_u32()? as usize; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?; - let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; - let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?; - let output = QMatMul::from_qtensor(ct.tensor(reader, "output.weight", device)?)?; + let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?; + let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?; let mut layers = Vec::with_capacity(block_count); let mapper = mapper.into_mapper(block_count, device)?; for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let device = mapper.device_for(layer_idx, false).unwrap_or(device); - let ffn_up = QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.ffn_up.weight"), - device, - )?)?; - let ffn_down = QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.ffn_down.weight"), - device, - )?)?; + let ffn_up = + QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?)?; + let ffn_down = + QMatMul::from_qtensor(ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?)?; let mlp = Mlp { ffn_up, ffn_down, i_size, }; let attn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?, rms_eps, )?; let ffn_norm = rms_norm( - ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?, + ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?, rms_eps, )?; layers.push(LayerWeights { - attn_qkv: QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.attn_qkv.weight"), - device, - )?)?, - attn_output: QMatMul::from_qtensor(ct.tensor( - reader, - &format!("{prefix}.attn_output.weight"), - device, - )?)?, + attn_qkv: QMatMul::from_qtensor( + ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?, + )?, + attn_output: QMatMul::from_qtensor( + ct.tensor(&format!("{prefix}.attn_output.weight"), device)?, + )?, attn_norm, ffn_norm, mlp, diff --git a/mistralrs-core/src/pipeline/chat_template.rs b/mistralrs-core/src/pipeline/chat_template.rs index e9d3d892af..1097ea3299 100644 --- a/mistralrs-core/src/pipeline/chat_template.rs +++ b/mistralrs-core/src/pipeline/chat_template.rs @@ -37,7 
+37,7 @@ pub struct BeginEndUnkTok( ); #[allow(dead_code)] -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Default)] pub struct ChatTemplate { add_bos_token: Option, add_eos_token: Option, diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index a71bdf400b..6e7021c4d6 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -317,7 +317,7 @@ impl Loader for GGMLLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let chat_template = get_chat_template(paths, &self.chat_template); + let chat_template = get_chat_template(paths, &self.chat_template, None); let max_seq_len = match model { Model::Llama(ref l) => l.max_seq_len, @@ -372,7 +372,7 @@ impl Loader for GGMLLoader { revision, self, self.quantized_model_id, - self.quantized_filename, + Some(vec![self.quantized_filename.as_ref().unwrap().clone()]), silent ); self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 7d74bdf950..0af9fa74ca 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -10,30 +10,31 @@ use super::{ }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; +use crate::gguf::{get_gguf_chat_template, Content}; use crate::lora::Ordering; use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig}; -use crate::pipeline::gguf_tokenizer::{convert_ggml_to_hf_tokenizer, ConversionResult}; +use crate::pipeline::ChatTemplate; use crate::pipeline::{get_chat_template, Cache}; -use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; use crate::utils::debug::setup_logger_and_debug; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; -use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline, DEBUG}; use crate::{ + convert_gguf_to_hf_tokenizer, models::quantized_llama::ModelWeights as QLlama, models::quantized_phi2::ModelWeights as QPhi, models::quantized_phi3::ModelWeights as QPhi3, utils::tokens::get_token, xlora_models::{XLoraQLlama, XLoraQPhi3}, + GgufTokenizerConversion, }; -use anyhow::{bail, Context, Result}; -use candle_core::quantized::{ - gguf_file::{self, Value as GgufValue}, - GgmlDType, +use crate::{ + do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths, Pipeline, }; +use anyhow::{bail, Context, Result}; +use candle_core::quantized::GgmlDType; use candle_core::{DType, Device, Tensor}; use either::Either; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; @@ -71,7 +72,7 @@ pub struct GGUFLoader { model_id: Option, config: GGUFSpecificConfig, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, xlora_model_id: Option, xlora_order: Option, no_kv_cache: bool, @@ -80,9 +81,9 @@ pub struct GGUFLoader { tgt_non_granular_index: Option, } -#[derive(Debug, EnumString)] +#[derive(Copy, Clone, Debug, EnumString)] #[strum(serialize_all = "kebab-case")] -enum GGUFArchitecture { +pub enum GGUFArchitecture { Llama, Mpt, Gptneox, @@ -100,7 +101,8 @@ enum GGUFArchitecture { // - Case-insensitive variant matching (TODO: is this desirable?) 
// - Customized error until potential upstream support: https://github.com/Peternator7/strum/issues/332 impl GGUFArchitecture { - fn from_value + std::fmt::Display>(value: T) -> Result { + /// GGUF architecture from a kebab-case representation. + pub fn from_value + std::fmt::Display>(value: T) -> Result { Self::from_str(&value.as_ref().to_ascii_lowercase()) .with_context(|| format!("Unknown GGUF architecture `{value}`")) .map_err(anyhow::Error::msg) @@ -119,7 +121,7 @@ pub struct GGUFLoaderBuilder { model_id: Option, config: GGUFSpecificConfig, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, xlora_model_id: Option, kind: ModelKind, xlora_order: Option, @@ -137,7 +139,7 @@ impl GGUFLoaderBuilder { chat_template: Option, tok_model_id: Option, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, ) -> Self { let kind = ModelKind::Quantized { quant: QuantizationKind::Gguf, @@ -222,7 +224,7 @@ impl GGUFLoader { model_id: Option, config: GGUFSpecificConfig, quantized_model_id: String, - quantized_filename: String, + quantized_filename: Vec, xlora_model_id: Option, kind: ModelKind, xlora_order: Option, @@ -258,28 +260,6 @@ impl GGUFLoader { } } -fn parse_gguf_value(value: &GgufValue) -> String { - match value { - GgufValue::Array(vs) => vs - .iter() - .map(parse_gguf_value) - .collect::>() - .join(", "), - GgufValue::Bool(b) => b.to_string(), - GgufValue::F32(x) => x.to_string(), - GgufValue::F64(x) => x.to_string(), - GgufValue::I8(x) => x.to_string(), - GgufValue::I16(x) => x.to_string(), - GgufValue::I32(x) => x.to_string(), - GgufValue::I64(x) => x.to_string(), - GgufValue::String(x) => x.to_string(), - GgufValue::U8(x) => x.to_string(), - GgufValue::U16(x) => x.to_string(), - GgufValue::U32(x) => x.to_string(), - GgufValue::U64(x) => x.to_string(), - } -} - impl Loader for GGUFLoader { #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn load_model_from_hf( @@ -324,50 +304,25 @@ impl Loader for GGUFLoader { info!("Loading model `{}` on {device:?}...", self.get_id()); } - let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?; - let model = gguf_file::Content::read(&mut file) - .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?; - let arch = model.metadata["general.architecture"] - .to_string() - .context("Model metadata should have declared an architecture") - .and_then(GGUFArchitecture::from_value)?; - - info!("Model config:"); - let mut sorted_keys = model.metadata.keys().collect::>(); - sorted_keys.sort(); - for name in sorted_keys { - if !name.contains("tokenizer") { - let value = parse_gguf_value(&model.metadata[name]); - println!("{name}: {}", value); - } + let mut files = Vec::new(); + for weight_filename in paths.get_weight_filenames() { + files.push(std::fs::File::open(weight_filename)?); } + let mut files = files.iter_mut().collect::>(); - if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { - let mut tensors = Vec::new(); - for (name, info) in &model.tensor_infos { - tensors.push(format!( - "name = `{name}`, shape = {:?}, dtype = {:?}", - info.shape.clone(), - info.ggml_dtype - )); - } - fs::write( - "mistralrs_gguf_tensors.txt", - serde_json::to_string_pretty(&tensors).expect("Serialization failed."), - )?; - - info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); - } + let content = Content::from_readers(&mut files)?; + let arch = content.arch(); - let ConversionResult { + // Set bos/eos/unk 
to None to avoid the override + let GgufTokenizerConversion { tokenizer, bos, eos, unk, } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() { - convert_ggml_to_hf_tokenizer(&model)? + convert_gguf_to_hf_tokenizer(&content)? } else { - ConversionResult { + GgufTokenizerConversion { tokenizer: get_tokenizer(paths.get_tokenizer_filename(), None)?, bos: None, eos: None, @@ -375,12 +330,19 @@ impl Loader for GGUFLoader { } }; + // Only load gguf chat template if there is nothing else + let gguf_chat_template = if paths.get_template_filename().is_none() { + get_gguf_chat_template(&content) + } else { + None + }; + let has_adapter = self.kind.is_adapted(); let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); let model_config = { // Base config (quantization only): - let quant = ModelConfig::ParamsGGUF((model, &mut file).into(), (device, mapper).into()); + let quant = ModelConfig::ParamsGGUF(content, (device, mapper).into()); // With optional adapter config: let mut adapter = None; @@ -418,7 +380,7 @@ impl Loader for GGUFLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let mut chat_template = get_chat_template(paths, &self.chat_template); + let mut chat_template = get_chat_template(paths, &self.chat_template, gguf_chat_template); let max_seq_len = match model { Model::Llama(ref l) => l.max_seq_len, diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 37a9ce4eb9..66bccdac45 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -165,8 +165,17 @@ macro_rules! get_paths { None }; - info!("Loading `tokenizer_config.json` at `{}`", $this.model_id); - let template_filename = $crate::api_get_file!(api, "tokenizer_config.json", model_id); + let template_filename = if let Some(ref p) = $this.chat_template { + info!("Using chat template file at `{p}`"); + Some(PathBuf::from_str(p)?) + } else { + info!("Loading `tokenizer_config.json` at `{}`", $this.model_id); + Some($crate::api_get_file!( + api, + "tokenizer_config.json", + model_id + )) + }; Ok(Box::new($path_name { tokenizer_filename, @@ -205,17 +214,22 @@ macro_rules! get_paths_gguf { let chat_template = if let Some(ref p) = $this.chat_template { if p.ends_with(".json") { info!("Using chat template file at `{p}`"); - PathBuf::from_str(p)? + Some(PathBuf::from_str(p)?) } else { - PathBuf::from_str("")? 
+ panic!("Specified chat template file must end with .json"); } } else { - info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id); - $crate::api_get_file!( - api, - "tokenizer_config.json", - model_id - ) // Will be loaded from inside gguf file + if $this.model_id.is_none() { + None + } else { + info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id); + let res = $crate::api_get_file!( + api, + "tokenizer_config.json", + model_id + ); + Some(res) + } }; let filenames = get_model_paths( diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 4ea3f6d447..21769f6b03 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -2,7 +2,6 @@ mod cache_manager; mod chat_template; mod ggml; mod gguf; -mod gguf_tokenizer; mod inputs_processor; mod isq; mod macros; @@ -21,7 +20,7 @@ use candle_core::quantized::GgmlDType; use chat_template::ChatTemplate; use core::fmt; pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig}; -pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig}; +pub use gguf::{GGUFArchitecture, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig}; pub use isq::IsqModel; pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig}; pub use normal_loaders::{ @@ -64,8 +63,8 @@ pub trait ModelPaths { /// See: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer fn get_tokenizer_filename(&self) -> &PathBuf; - /// Content expected to deserialize to [`ChatTemplate`]. - fn get_template_filename(&self) -> &PathBuf; + /// File where the content is expected to deserialize to [`ChatTemplate`]. + fn get_template_filename(&self) -> &Option; /// Optional adapter files. `(String, PathBuf)` is of the form `(id name, path)`. fn get_adapter_filenames(&self) -> &Option>; @@ -99,7 +98,7 @@ pub trait ModelPaths { pub struct LocalModelPaths
<P>
{ tokenizer_filename: P, config_filename: P, - template_filename: P, + template_filename: Option
<P>
, filenames: Vec
<P>
, xlora_adapter_filenames: Option>, xlora_adapter_configs: Option>, @@ -132,7 +131,7 @@ impl
<P>
LocalModelPaths
<P>
{ Self { tokenizer_filename, config_filename, - template_filename, + template_filename: Some(template_filename), filenames, xlora_adapter_filenames, xlora_adapter_configs, @@ -172,7 +171,7 @@ impl ModelPaths for LocalModelPaths { fn get_ordering(&self) -> &Option { &self.xlora_ordering } - fn get_template_filename(&self) -> &PathBuf { + fn get_template_filename(&self) -> &Option { &self.template_filename } fn get_gen_conf_filename(&self) -> Option<&PathBuf> { diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 5ed6c022e2..6a71da6b43 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -292,7 +292,7 @@ impl Loader for NormalLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let chat_template = get_chat_template(paths, &self.chat_template); + let chat_template = get_chat_template(paths, &self.chat_template, None); if let Some(in_situ_quant) = in_situ_quant { model.quantize(in_situ_quant, device.clone())?; diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index 0aea537d43..cc792fec9c 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -2,7 +2,6 @@ use std::{ collections::HashMap, fs, path::{Path, PathBuf}, - str::FromStr, }; use anyhow::Result; @@ -251,14 +250,16 @@ pub fn get_model_paths( revision: String, token_source: &TokenSource, quantized_model_id: &Option, - quantized_filename: &Option, + quantized_filename: &Option>, api: &ApiRepo, model_id: &Path, ) -> Result> { match &quantized_filename { - Some(name) => match quantized_model_id.as_ref().unwrap().as_str() { - "" => Ok(vec![PathBuf::from_str(name).unwrap()]), - id => { + Some(names) => { + let id = quantized_model_id.as_ref().unwrap(); + let mut files = Vec::new(); + + for name in names { let qapi = ApiBuilder::new() .with_progress(true) .with_token(get_token(token_source)?) @@ -269,9 +270,10 @@ pub fn get_model_paths( revision.clone(), )); let model_id = Path::new(&id); - Ok(vec![api_get_file!(qapi, name, model_id)]) + files.push(api_get_file!(qapi, name, model_id)); } - }, + Ok(files) + } None => { let mut filenames = vec![]; for rfilename in api_dir_list!(api, model_id).filter(|x| x.ends_with(".safetensors")) { @@ -285,38 +287,58 @@ pub fn get_model_paths( /// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`]. /// If the the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not /// have a `chat_template`, use the provided one. +/// +/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file. +/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially. +/// Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this +/// is used.* #[allow(clippy::borrowed_box)] pub(crate) fn get_chat_template( paths: &Box, - chat_template: &Option, + chat_template_fallback: &Option, + chat_template_ovrd: Option, ) -> ChatTemplate { - let template_filename = if paths.get_template_filename().to_string_lossy().is_empty() { - PathBuf::from( - chat_template - .as_ref() - .expect("A tokenizer config or chat template file path must be specified."), - ) + // Get template content, this may be overridden. 
+ let template_content = if let Some(template_filename) = paths.get_template_filename() { + if template_filename + .extension() + .expect("Template filename must be a file") + .to_string_lossy() + != "json" + { + panic!("Template filename {template_filename:?} must end with `.json`."); + } + fs::read_to_string(template_filename).expect("Loading chat template failed.") + } else if chat_template_fallback + .as_ref() + .is_some_and(|f| f.ends_with(".json")) + { + // User specified a file + let template_filename = chat_template_fallback + .as_ref() + .expect("A tokenizer config or chat template file path must be specified."); + fs::read_to_string(template_filename).expect("Loading chat template failed.") } else { - paths.get_template_filename().clone() + panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there."); + }; + + let template: ChatTemplate = match chat_template_ovrd { + Some(chat_template) => { + // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves. + info!("Using literal chat template."); + let mut template = ChatTemplate::default(); + template.chat_template = Some(chat_template); + template + } + None => serde_json::from_str(&template_content).unwrap(), }; - if template_filename - .extension() - .expect("Template filename must be a file") - .to_string_lossy() - != "json" - { - panic!("Template filename {template_filename:?} must end with `.json`."); - } - let template: ChatTemplate = serde_json::from_str( - &fs::read_to_string(&template_filename).expect("Deserialization of chat template failed."), - ) - .unwrap(); #[derive(Debug, serde::Deserialize)] struct SpecifiedTemplate { chat_template: String, bos_token: Option, eos_token: Option, + unk_token: Option, } match &template.chat_template { @@ -324,34 +346,34 @@ pub(crate) fn get_chat_template( None => { info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template."); let mut deser: HashMap = - serde_json::from_str(&fs::read_to_string(&template_filename).unwrap()).unwrap(); + serde_json::from_str(&template_content).unwrap(); - match chat_template.clone() { + match chat_template_fallback.clone() { Some(t) => { - if t.ends_with(".json") { - info!("Loading specified loading chat template file at `{t}`."); - let templ: SpecifiedTemplate = - serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap(); + info!("Loading specified loading chat template file at `{t}`."); + let templ: SpecifiedTemplate = + serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap(); + deser.insert( + "chat_template".to_string(), + Value::String(templ.chat_template), + ); + if templ.bos_token.is_some() { deser.insert( - "chat_template".to_string(), - Value::String(templ.chat_template), + "bos_token".to_string(), + Value::String(templ.bos_token.unwrap()), + ); + } + if templ.eos_token.is_some() { + deser.insert( + "eos_token".to_string(), + Value::String(templ.eos_token.unwrap()), + ); + } + if templ.unk_token.is_some() { + deser.insert( + "unk_token".to_string(), + Value::String(templ.unk_token.unwrap()), ); - if templ.bos_token.is_some() { - deser.insert( - "bos_token".to_string(), - Value::String(templ.bos_token.unwrap()), - ); - } - if templ.eos_token.is_some() { - deser.insert( - "eos_token".to_string(), - Value::String(templ.eos_token.unwrap()), - ); - } - info!("Loaded chat template file."); - } else { - deser.insert("chat_template".to_string(), 
Value::String(t)); - info!("Loaded specified literal chat template."); } } None => { diff --git a/mistralrs-core/src/toml_selector.rs b/mistralrs-core/src/toml_selector.rs index 478d940eb0..ae7a2d2501 100644 --- a/mistralrs-core/src/toml_selector.rs +++ b/mistralrs-core/src/toml_selector.rs @@ -5,7 +5,7 @@ use serde::Deserialize; use crate::{ GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig, - SpeculativeLoader, + SpeculativeLoader, GGUF_MULTI_FILE_DELIMITER, }; fn default_repeat_last_n() -> usize { @@ -307,7 +307,10 @@ fn loader_from_selected( args.chat_template, Some(tok_model_id), quantized_model_id, - quantized_filename, + quantized_filename + .split(GGUF_MULTI_FILE_DELIMITER) + .map(|s| s.to_string()) + .collect::>(), ) .build(), TomlModelSelected::XLoraGGUF { @@ -324,7 +327,10 @@ fn loader_from_selected( args.chat_template, tok_model_id, quantized_model_id, - quantized_filename, + quantized_filename + .split(GGUF_MULTI_FILE_DELIMITER) + .map(|s| s.to_string()) + .collect::>(), ) .with_xlora( xlora_model_id, @@ -349,7 +355,10 @@ fn loader_from_selected( args.chat_template, tok_model_id, quantized_model_id, - quantized_filename, + quantized_filename + .split(GGUF_MULTI_FILE_DELIMITER) + .map(|s| s.to_string()) + .collect::>(), ) .with_lora( adapters_model_id, diff --git a/mistralrs-core/src/utils/max_seq_len.rs b/mistralrs-core/src/utils/max_seq_len.rs deleted file mode 100644 index 3ac96a29cc..0000000000 --- a/mistralrs-core/src/utils/max_seq_len.rs +++ /dev/null @@ -1,20 +0,0 @@ -use candle_core::{ - quantized::gguf_file::{Value, ValueType}, - Result, -}; -use tracing::warn; - -/// Extract a u32 or u8 max seq len. Warns if error and then uses a default -pub(crate) fn get_gguf_max_seq_len(max_seq_len: Result<&Value>, default: u64) -> u64 { - match max_seq_len { - Ok(m) => match m.value_type() { - ValueType::U32 => m.to_u32().unwrap() as u64, - ValueType::U64 => m.to_u64().unwrap(), - _ => default, - }, - Err(_) => { - warn!("GGUF file does not specify a context window, using {default}."); - default - } - } -} diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 854aa120ad..8b6bb9ae04 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -1,5 +1,4 @@ pub(crate) mod debug; -pub(crate) mod max_seq_len; pub(crate) mod model_config; pub(crate) mod progress; pub(crate) mod tokenizer; diff --git a/mistralrs-core/src/utils/model_config.rs b/mistralrs-core/src/utils/model_config.rs index 7315f27b5f..35f3aa11fd 100644 --- a/mistralrs-core/src/utils/model_config.rs +++ b/mistralrs-core/src/utils/model_config.rs @@ -1,10 +1,11 @@ use super::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use anyhow::Result; -use candle_core::quantized::{ggml_file, gguf_file}; +use candle_core::quantized::ggml_file; use candle_nn::VarBuilder; use std::{collections::HashMap, path::PathBuf}; use crate::{ + gguf::Content, lora::{LoraConfig, Ordering}, pipeline::ModelPaths, xlora_models::XLoraConfig, @@ -17,12 +18,6 @@ pub struct FileGGML { pub gqa: usize, } -#[derive(derive_more::From)] -pub struct FileGGUF<'a> { - pub ct: gguf_file::Content, - pub reader: &'a mut std::fs::File, -} - #[derive(derive_more::From)] pub struct Device<'a> { pub device: &'a candle_core::Device, @@ -94,7 +89,7 @@ impl<'a> Adapter<'a> { // New type wrappers that segment the distinct parameter sets used by `from_ggml()` + `from_gguf()` methods: 
pub struct ParamsGGML(pub FileGGML); -pub struct ParamsGGUF<'a>(pub FileGGUF<'a>, pub Device<'a>); +pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(pub Content<'a, R>, pub Device<'a>); // A `None` type vs the `Some` type (`Adapter<'a>`) pub struct NoAdapter {} @@ -103,7 +98,7 @@ pub struct NoAdapter {} // (required workaround to support impl on subtypes, otherwise would use an enum) pub trait QuantParams {} impl QuantParams for ParamsGGML {} -impl QuantParams for ParamsGGUF<'_> {} +impl QuantParams for ParamsGGUF<'_, R> {} // Emulates `Option` but is compatible as a type bound in `impl` for Some vs None pub trait MaybeAdapter {} @@ -155,8 +150,7 @@ pub trait FromGGML { pub trait FromGGUF { fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + ct: Content<'_, R>, device: &candle_core::Device, mapper: DeviceMapMetadata, ) -> Result @@ -181,8 +175,7 @@ pub trait FromAdapterGGML { pub trait FromAdapterGGUF { #[allow(clippy::too_many_arguments)] fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + ct: Content<'_, R>, device: &candle_core::Device, lora_config: &[((String, String), LoraConfig)], vb: &VarBuilder, @@ -232,20 +225,20 @@ impl Config> { } } -impl Config, NoAdapter> { +impl Config, NoAdapter> { pub fn try_into_model(self) -> Result { // Destructure props: - let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant; + let ParamsGGUF(content, Device { device, mapper }) = self.quant; // Forwards all structured fields above into the required flattened param sequence: - T::from_gguf(ct, reader, device, mapper) + T::from_gguf(content, device, mapper) } } -impl Config, Adapter<'_>> { +impl Config, Adapter<'_>> { pub fn try_into_model(self) -> Result { // Destructure props: - let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant; + let ParamsGGUF(content, Device { device, mapper }) = self.quant; let Adapter { xlora_config, @@ -257,8 +250,7 @@ impl Config, Adapter<'_>> { // Forwards all structured fields above into the required flattened param sequence: T::from_gguf( - ct, - reader, + content, device, lora_config, &vb, @@ -299,10 +291,10 @@ impl TryFrom> for XLoraQLlama { akin! { let &models_gguf = [QLlama, QPhi, QPhi3]; - impl TryFrom>> for *models_gguf { + impl TryFrom>> for *models_gguf { type Error = candle_core::Error; - fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result { + fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result { let config = params.expect_quantized("`Config` should be GGUF Quantized"); config.try_into_model() } @@ -312,10 +304,10 @@ akin! { akin! 
{ let &models_gguf_a = [XLoraQLlama, XLoraQPhi3]; - impl TryFrom>> for *models_gguf_a { + impl TryFrom>> for *models_gguf_a { type Error = candle_core::Error; - fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result { + fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result { let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter"); config.try_into_model() } diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs index 8eef06eeb2..dc98c64193 100644 --- a/mistralrs-core/src/xlora_models/quantized_llama.rs +++ b/mistralrs-core/src/xlora_models/quantized_llama.rs @@ -2,12 +2,12 @@ use std::collections::HashMap; +use crate::gguf::Content; use crate::lora::{ get_lora_cfg, AdapterSwapper, LinearLayerLike, LoraConfig, Merge, Ordering, QLoraLinear, }; -use crate::utils::max_seq_len::get_gguf_max_seq_len; +use candle_core::quantized::ggml_file; use candle_core::quantized::QMatMul; -use candle_core::quantized::{ggml_file, gguf_file}; use candle_core::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, RotaryEmbedding, VarBuilder}; use tqdm::Iter; @@ -447,8 +447,7 @@ impl ModelConfig::FromAdapterGGML for ModelWeights { impl ModelConfig::FromAdapterGGUF for ModelWeights { #[allow(clippy::too_many_arguments)] fn from_gguf( - ct: gguf_file::Content, - reader: &mut R, + mut ct: Content<'_, R>, device: &Device, lora_config: &[((String, String), LoraConfig)], vb: &VarBuilder, @@ -457,46 +456,48 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { mapper: DeviceMapMetadata, preload_adapters: &Option>, ) -> Result { - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle_core::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; verify_sanity_gguf( - md_get("general.architecture")?.to_string().unwrap(), + ct.get_metadata("general.architecture")? + .to_string() + .unwrap(), "llama", )?; verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?; // Parameter extraction from metadata. - let n_expert = md_get("llama.expert_count") + let n_expert = ct + .get_metadata("llama.expert_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let n_expert_used = md_get("llama.expert_used_count") + let n_expert_used = ct + .get_metadata("llama.expert_used_count") .and_then(|v| v.to_u32()) .unwrap_or(0) as usize; - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + let head_count = ct.get_metadata("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = ct.get_metadata("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = ct.get_metadata("llama.block_count")?.to_u32()? as usize; + let embedding_length = ct.get_metadata("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = ct.get_metadata("llama.rope.dimension_count")?.to_u32()? as usize; // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. - let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + let rms_norm_eps = ct + .get_metadata("llama.attention.layer_norm_rms_epsilon")? 
+            .to_f32()?;
 
-        let rope_freq_base = md_get("llama.rope.freq_base")
+        let rope_freq_base = ct
+            .get_metadata("llama.rope.freq_base")
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
 
         let head_dim = embedding_length / head_count;
 
-        let max_seq_len =
-            get_gguf_max_seq_len(md_get("llama.context_length"), MAX_SEQ_LEN as u64) as usize;
+        let max_seq_len = ct
+            .get_metadata("llama.context_length")?
+            .to_u64()
+            .unwrap_or(MAX_SEQ_LEN as u64) as usize;
 
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = QRmsNorm::new(
-            ct.tensor(reader, "output_norm.weight", device)?,
-            rms_norm_eps,
-        )?;
-        let output = ct.tensor(reader, "output.weight", device)?;
+        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
+        let output = ct.tensor("output.weight", device)?;
         let mut layers = Vec::with_capacity(block_count);
         let mut count = 0;
         let mapper = mapper.into_mapper(block_count, device)?;
@@ -513,18 +514,14 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights {
                 DType::F32,
             )?;
 
-            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
-            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
-            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-            let attention_wo =
-                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
+            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
+            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
+            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
             let mlp_or_moe = if n_expert <= 1 {
-                let feed_forward_w1 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-                let feed_forward_w2 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
-                let feed_forward_w3 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
+                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
+                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                 let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                 let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                 let cfg_w3 = get_lora_cfg(&feed_forward_w3);
@@ -562,15 +559,15 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights {
                 })
             } else {
                 let feed_forward_gate_inp =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
+                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                 let mut experts = Vec::with_capacity(n_expert);
                 for i in 0..n_expert {
                     let feed_forward_w1 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
+                        ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                     let feed_forward_w2 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
+                        ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                     let feed_forward_w3 =
-                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
+                        ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                     let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                     let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                     let cfg_w3 = get_lora_cfg(&feed_forward_w3);
@@ -613,9 +610,8 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights {
                     experts,
                 }
             };
-            let attention_norm =
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
-            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
+            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
+            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
             let cfgq = get_lora_cfg(&attention_wq);
             let cfgk = get_lora_cfg(&attention_wk);
             let cfgv = get_lora_cfg(&attention_wv);
diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs
index 248bc4175f..44feddc344 100644
--- a/mistralrs-core/src/xlora_models/quantized_phi3.rs
+++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs
@@ -3,6 +3,7 @@
 use std::collections::HashMap;
 
 use crate::device_map::DeviceMapper;
+use crate::gguf::Content;
 use crate::layers::repeat_kv;
 use crate::layers::verify_sanity_gguf;
 use crate::layers::CausalMasker;
@@ -18,7 +19,6 @@ use crate::lora::Ordering;
 use crate::lora::QLoraLinear;
 use crate::pipeline::extract_logits;
 use crate::DeviceMapMetadata;
-use candle_core::quantized::gguf_file;
 use candle_core::quantized::QMatMul;
 use candle_core::quantized::QTensor;
 use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
@@ -216,8 +216,7 @@ fn precomput_freqs_cis(
 impl ModelConfig::FromAdapterGGUF for ModelWeights {
     #[allow(clippy::too_many_arguments)]
     fn from_gguf<R: std::io::Seek + std::io::Read>(
-        ct: gguf_file::Content,
-        reader: &mut R,
+        mut ct: Content<'_, R>,
         device: &Device,
         lora_config: &[((String, String), LoraConfig)],
         vb: &VarBuilder,
@@ -226,36 +225,39 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights {
         mapper: DeviceMapMetadata,
         preload_adapters: &Option<HashMap<String, (QMatMul, LoraConfig)>>,
     ) -> Result<Self> {
-        let md_get = |s: &str| match ct.metadata.get(s) {
-            None => candle_core::bail!("cannot find {s} in metadata"),
-            Some(v) => Ok(v),
-        };
-        verify_sanity_gguf(md_get("general.architecture")?.to_string().unwrap(), "phi3")?;
+        verify_sanity_gguf(
+            ct.get_metadata("general.architecture")?
+                .to_string()
+                .unwrap(),
+            "phi3",
+        )?;
         verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;
 
         // Parameter extraction from metadata.
-        let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
-        let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
-        let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
-        let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
-        let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
-        let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
-        let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
-        let context_window = md_get("phi3.context_length")?.to_u32()? as usize;
+        let head_count = ct.get_metadata("phi3.attention.head_count")?.to_u32()? as usize;
+        let head_count_kv = ct.get_metadata("phi3.attention.head_count_kv")?.to_u32()? as usize;
+        let block_count = ct.get_metadata("phi3.block_count")?.to_u32()? as usize;
+        let embedding_length = ct.get_metadata("phi3.embedding_length")?.to_u32()? as usize;
+        let i_size = ct.get_metadata("phi3.feed_forward_length")?.to_u32()? as usize;
+        let rope_dim = ct.get_metadata("phi3.rope.dimension_count")?.to_u32()? as usize;
+        let rms_eps = ct
+            .get_metadata("phi3.attention.layer_norm_rms_epsilon")?
+            .to_f32()? as f64;
+        let context_window = ct.get_metadata("phi3.context_length")?.to_u32()? as usize;
         let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?;
 
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = ct.tensor("token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
-        let output = QMatMul::from_qtensor(ct.tensor(reader, "output.weight", device)?)?;
+        let output_norm = rms_norm(ct.tensor("output_norm.weight", device)?, rms_eps)?;
+        let output = QMatMul::from_qtensor(ct.tensor("output.weight", device)?)?;
         let mut layers = Vec::with_capacity(block_count);
         let mapper = mapper.into_mapper(block_count, device)?;
         let mut count = 0;
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
             let device = mapper.device_for(layer_idx, false).unwrap_or(device);
-            let ffn_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-            let ffn_down = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+            let ffn_up = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
+            let ffn_down = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
             let cfg_up = get_lora_cfg(&ffn_up);
             let cfg_down = get_lora_cfg(&ffn_down);
             let mlp = Mlp {
@@ -282,15 +284,15 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights {
                 i_size,
             };
             let attn_norm = rms_norm(
-                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
+                ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?,
                 rms_eps,
             )?;
             let ffn_norm = rms_norm(
-                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
+                ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?,
                 rms_eps,
             )?;
-            let qkv = ct.tensor(reader, &format!("{prefix}.attn_qkv.weight"), device)?;
-            let output = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+            let qkv = ct.tensor(&format!("{prefix}.attn_qkv.weight"), device)?;
+            let output = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
             let cfg_qkv = get_lora_cfg(&qkv);
             let cfg_out = get_lora_cfg(&output);
             layers.push(LayerWeights {
diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md
index 359ac00e80..ebf49b824b 100644
--- a/mistralrs-pyo3/API.md
+++ b/mistralrs-pyo3/API.md
@@ -9,7 +9,7 @@ These are API docs for the `mistralrs` package.
 
 ## `Which`
 
-Each `*_model_id` may be a HF hub repo or a local path.
+Each `*_model_id` may be a HF hub repo or a local path. For quantized GGUF models, a list is accepted if multiple files must be specified. Additionally, for models without quantization, the model architecture should be provided as the `arch` parameter in contrast to GGUF models which encode the architecture in the file.
 
 It should be one of the following:
 - `mistral`
@@ -49,13 +49,13 @@ class Which(Enum):
     class GGUF:
         tok_model_id: str
         quantized_model_id: str
-        quantized_filename: str
+        quantized_filename: str | list[str]
        repeat_last_n: int = 64
     @dataclass
     class XLoraGGUF:
         tok_model_id: str
         quantized_model_id: str
-        quantized_filename: str
+        quantized_filename: str | list[str]
        xlora_model_id: str
         order: str
         tgt_non_granular_index: int | None = None
@@ -64,7 +64,7 @@ class Which(Enum):
     class LoraGGUF:
         tok_model_id: str
         quantized_model_id: str
-        quantized_filename: str
+        quantized_filename: str | list[str]
        adapters_model_id: str
         order: str
         repeat_last_n: int = 64
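Since `quantized_filename` is now typed as `str | list[str]`, a sharded GGUF model can be selected from Python by passing every shard in a single list. A minimal sketch of that call shape, mirroring the `Runner`/`Which.GGUF` usage from the updated README example below; the model ids and shard filenames here are placeholders, not files published by any particular repo:

```python
from mistralrs import Runner, Which

# Placeholder ids and filenames: substitute the real GGUF repo and its shard names.
runner = Runner(
    which=Which.GGUF(
        tok_model_id="some-org/Some-Large-Model",
        quantized_model_id="some-org/Some-Large-Model-GGUF",
        quantized_filename=[
            "some-large-model-00001-of-00002.gguf",
            "some-large-model-00002-of-00002.gguf",
        ],
        tokenizer_json=None,
        repeat_last_n=64,
    )
)
```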
diff --git a/mistralrs-pyo3/README.md b/mistralrs-pyo3/README.md
index 1d57af08bd..5e058da37c 100644
--- a/mistralrs-pyo3/README.md
+++ b/mistralrs-pyo3/README.md
@@ -115,18 +115,18 @@ We also provide [a cookbook here](../examples/python/cookbook.ipynb)!
 ## Example
 ```python
-from mistralrs import ModelKind, MistralLoader, ChatCompletionRequest
-
-kind = ModelKind.QuantizedGGUF
-loader = MistralLoader(
-    model_id="mistralai/Mistral-7B-Instruct-v0.1",
-    kind=kind,
-    no_kv_cache=False,
-    repeat_last_n=64,
-    quantized_model_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
-    quantized_filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf",
+from mistralrs import Runner, Which, ChatCompletionRequest
+
+runner = Runner(
+    which=Which.GGUF(
+        tok_model_id="mistralai/Mistral-7B-Instruct-v0.1",
+        quantized_model_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
+        quantized_filename="mistral-7b-instruct-v0.1.Q4_K_M.gguf",
+        tokenizer_json=None,
+        repeat_last_n=64,
+    )
 )
-runner = loader.load()
+
 res = runner.send_chat_completion_request(
     ChatCompletionRequest(
         model="mistral",
@@ -134,10 +134,12 @@ res = runner.send_chat_completion_request(
             {"role": "user", "content": "Tell me a story about the Rust type system."}
         ],
         max_tokens=256,
-        frequency_penalty=1.0,
+        presence_penalty=1.0,
         top_p=0.1,
         temperature=0.1,
     )
 )
-print(res)
+print(res.choices[0].message.content)
+print(res.usage)
+
 ```
\ No newline at end of file
diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi
index f1d7c46c7c..a307bfd6fc 100644
--- a/mistralrs-pyo3/mistralrs.pyi
+++ b/mistralrs-pyo3/mistralrs.pyi
@@ -95,13 +95,13 @@ class Which(Enum):
     class GGUF:
         tok_model_id: str
         quantized_model_id: str
-        quantized_filename: str
+        quantized_filename: str | list[str]
        repeat_last_n: int = 64
     @dataclass
     class XLoraGGUF:
         tok_model_id: str
         quantized_model_id: str
-        quantized_filename: str
+        quantized_filename: str | list[str]
        xlora_model_id: str
         order: str
         tgt_non_granular_index: int | None = None
@@ -110,7 +110,7 @@ class Which(Enum):
     class LoraGGUF:
         tok_model_id: str
         quantized_model_id: str
-        quantized_filename: str
+        quantized_filename: str | list[str]
        adapters_model_id: str
         order: str
         repeat_last_n: int = 64
diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs
index 2f15d8ac58..1c9ff39166 100644
--- a/mistralrs-pyo3/src/lib.rs
+++ b/mistralrs-pyo3/src/lib.rs
@@ -177,7 +177,10 @@ fn parse_which(
             chat_template,
             tok_model_id,
             quantized_model_id,
-            quantized_filename,
+            match quantized_filename {
+                Either::Left(l) => vec![l],
+                Either::Right(r) => r,
+            },
         )
         .build(),
         Which::XLoraGGUF {
@@ -195,7 +198,10 @@ fn parse_which(
             chat_template,
             tok_model_id,
             quantized_model_id,
-            quantized_filename,
+            match quantized_filename {
+                Either::Left(l) => vec![l],
+                Either::Right(r) => r,
+            },
         )
         .with_xlora(
             xlora_model_id,
@@ -222,7 +228,10 @@ fn parse_which(
             chat_template,
             tok_model_id,
             quantized_model_id,
-            quantized_filename,
+            match quantized_filename {
+                Either::Left(l) => vec![l],
+                Either::Right(r) => r,
+            },
         )
         .with_lora(
             adapters_model_id,
diff --git a/mistralrs-pyo3/src/which.rs b/mistralrs-pyo3/src/which.rs
index a5a33a6123..0b121d6083 100644
--- a/mistralrs-pyo3/src/which.rs
+++ b/mistralrs-pyo3/src/which.rs
@@ -1,3 +1,4 @@
+use either::Either;
 use mistralrs_core::NormalLoaderType;
 use pyo3::pyclass;
 
@@ -58,14 +59,14 @@ pub enum Which {
     GGUF {
         tok_model_id: Option<String>,
         quantized_model_id: String,
-        quantized_filename: String,
+        quantized_filename: Either<String, Vec<String>>,
         repeat_last_n: Option<usize>,
     },
 
     XLoraGGUF {
         tok_model_id: Option<String>,
         quantized_model_id: String,
-        quantized_filename: String,
+        quantized_filename: Either<String, Vec<String>>,
         repeat_last_n: Option<usize>,
         xlora_model_id: String,
         order: String,
@@ -75,7 +76,7 @@ pub enum Which {
     LoraGGUF {
         tok_model_id: Option<String>,
         quantized_model_id: String,
-        quantized_filename: String,
+        quantized_filename: Either<String, Vec<String>>,
         repeat_last_n: Option<usize>,
         adapters_model_id: String,
         order: String,
diff --git a/mistralrs/examples/gguf_locally/main.rs b/mistralrs/examples/gguf_locally/main.rs
index b04fc9fa53..e38e7e8c40 100644
--- a/mistralrs/examples/gguf_locally/main.rs
+++ b/mistralrs/examples/gguf_locally/main.rs
@@ -17,7 +17,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
         Some("chat_templates/mistral.json".to_string()),
         None,
         ".".to_string(),
-        "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
+        vec!["mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string()],
     )
     .build();
     // Load, into a Pipeline
diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs
index b6539edaf2..076ad99788 100644
--- a/mistralrs/examples/quantized/main.rs
+++ b/mistralrs/examples/quantized/main.rs
@@ -15,7 +15,7 @@ fn setup() -> anyhow::Result<Arc<MistralRs>> {
         None,
         Some("mistralai/Mistral-7B-Instruct-v0.1".to_string()),
         "TheBloke/Mistral-7B-Instruct-v0.1-GGUF".to_string(),
-        "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
+        vec!["mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string()],
     )
     .build();
     // Load, into a Pipeline
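As the `parse_which` changes above show, a bare string (`Either::Left`) is wrapped into a one-element `Vec`, so existing single-file callers keep working unchanged while sharded models pass a list (`Either::Right`). A hedged Python illustration of the two equivalent call shapes; ids and filenames are placeholders, and optional fields such as `tokenizer_json` from the README example are omitted:

```python
from mistralrs import Which

# Single file: a plain string is still accepted (normalized to a one-element list internally).
single_file = Which.GGUF(
    tok_model_id="some-org/Some-Model",
    quantized_model_id="some-org/Some-Model-GGUF",
    quantized_filename="some-model.Q4_K_M.gguf",
    repeat_last_n=64,
)

# Sharded model: pass every shard in one list.
sharded = Which.GGUF(
    tok_model_id="some-org/Some-Large-Model",
    quantized_model_id="some-org/Some-Large-Model-GGUF",
    quantized_filename=[
        "some-large-model-00001-of-00002.gguf",
        "some-large-model-00002-of-00002.gguf",
    ],
    repeat_last_n=64,
)
```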