diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 9d4c9c1999..9d967e15c1 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -55,6 +55,9 @@ once_cell.workspace = true toml = "0.8.12" strum = { version = "0.26", features = ["derive"] } derive_more = { version = "0.99.17", default-features = false, features = ["from"] } +akin = "0.4.0" +variantly = "0.4.0" +buildstructor = "0.5.4" tracing-subscriber.workspace = true reqwest = { version = "0.12.4", features = ["blocking"] } diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index e902b3cd13..70b36e38b3 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -10,6 +10,7 @@ use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention, }; use crate::pipeline::{extract_logits, Cache}; +use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; const MAX_SEQ_LEN: u32 = 4096; @@ -194,8 +195,8 @@ pub struct ModelWeights { mapper: Option>, } -impl ModelWeights { - pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { +impl ModelConfig::FromGGML for ModelWeights { + fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let rotary = RotaryEmbedding::new_partial( 10000., @@ -254,8 +255,10 @@ impl ModelWeights { mapper: None, }) } +} - pub fn from_gguf( +impl ModelConfig::FromGGUF for ModelWeights { + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -383,7 +386,9 @@ impl ModelWeights { mapper: Some(mapper), }) } +} +impl ModelWeights { pub fn forward( &mut self, x: &Tensor, diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index a002f0b898..0c4d3225e3 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -9,6 +9,7 @@ use crate::device_map::DeviceMapper; use crate::layers::ScaledDotProductAttention; use crate::layers::{repeat_kv, CausalMasker, QLinear}; use crate::pipeline::{extract_logits, Cache}; +use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; pub const MAX_SEQ_LEN: usize = 4096; @@ -141,8 +142,8 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result { Ok(ln) } -impl ModelWeights { - pub fn from_gguf( +impl ModelConfig::FromGGUF for ModelWeights { + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -211,7 +212,9 @@ impl ModelWeights { mapper, }) } +} +impl ModelWeights { pub fn forward( &mut self, input_ids: &Tensor, diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index 909dd90c2b..8149eea845 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -5,6 +5,7 @@ use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention, }; use crate::pipeline::Cache; +use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; use candle_core::quantized::gguf_file; use candle_core::quantized::QMatMul; @@ -159,8 +160,8 @@ fn precomput_freqs_cis( Ok((cos, sin)) } -impl ModelWeights { - pub fn from_gguf( +impl ModelConfig::FromGGUF for ModelWeights { + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -248,7 +249,9 @@ impl ModelWeights { max_seq_len: context_window, }) } +} +impl 
ModelWeights { pub fn forward(&mut self, input_ids: &Tensor, seqlen_offsets: &[usize]) -> Result { let (_b_sz, seq_len) = input_ids.dims2()?; let mut xs = self.tok_embeddings.forward(input_ids)?; diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 9114f70c98..aaeaea4730 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -1,7 +1,7 @@ use super::cache_manager::DefaultCacheManager; use super::{ - get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs, - ModelKind, ModelPaths, Pipeline, TokenSource, XLoraPaths, + get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader, + ModelInputs, ModelKind, ModelPaths, Pipeline, QuantizationKind, TokenSource, XLoraPaths, }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; @@ -11,8 +11,8 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; -use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use crate::xlora_models::NonGranularState; use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG}; use crate::{ @@ -96,12 +96,16 @@ impl GGMLLoaderBuilder { quantized_model_id: String, quantized_filename: String, ) -> Self { + let kind = ModelKind::Quantized { + quant: QuantizationKind::Ggml, + }; + Self { config, chat_template, tokenizer_json, model_id, - kind: ModelKind::QuantizedGGML, + kind, quantized_filename, quantized_model_id, ..Default::default() @@ -138,7 +142,8 @@ impl GGMLLoaderBuilder { no_kv_cache: bool, tgt_non_granular_index: Option, ) -> Self { - self.kind = ModelKind::XLoraGGML; + self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into(); + self.with_adapter( xlora_model_id, xlora_order, @@ -148,7 +153,8 @@ impl GGMLLoaderBuilder { } pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self { - self.kind = ModelKind::LoraGGML; + self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into(); + self.with_adapter(lora_model_id, lora_order, false, None) } @@ -236,7 +242,7 @@ impl Loader for GGMLLoader { if in_situ_quant.is_some() { anyhow::bail!( - "You are trying to in-situ quantize a GGUF model. This will not do anything." + "You are trying to in-situ quantize a GGML model. This will not do anything." ); } if !mapper.is_dummy() { @@ -267,69 +273,33 @@ impl Loader for GGMLLoader { info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`."); } - let mut is_lora = false; - let model = match self.kind { - ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml(model, self.config.gqa)?), - ModelKind::XLoraGGML => { - let vb = from_mmaped_safetensors( - vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - Model::XLoraLlama(XLoraQLlama::from_ggml( - model, - self.config.gqa, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - Some(paths.get_classifier_config().as_ref().unwrap().clone()), - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?) 
+ let has_adapter = self.kind.is_adapted(); + let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); + + let model_config = { + // Base config (quantization only): + let quant = ModelConfig::ParamsGGML((model, self.config.gqa).into()); + + // With optional adapter config: + let mut adapter = None; + if has_adapter { + adapter.replace(ModelConfig::Adapter::try_new( + paths, device, silent, is_xlora, + )?); } - ModelKind::LoraGGML => { - is_lora = true; - let vb = from_mmaped_safetensors( - vec![], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - Model::XLoraLlama(XLoraQLlama::from_ggml( - model, - self.config.gqa, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - None, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?) + + ModelConfig::ModelParams::builder() + .quant(quant) + .and_adapter(adapter) + .build() + }; + + // Config into model: + // NOTE: No architecture to infer like GGUF, Llama model is implicitly matched + let model = match self.kind { + ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?), + ModelKind::AdapterQuantized { .. } => { + Model::XLoraLlama(XLoraQLlama::try_from(model_config)?) } _ => unreachable!(), }; @@ -345,10 +315,6 @@ impl Loader for GGMLLoader { Model::XLoraLlama(ref xl) => xl.max_seq_len, }; let tok_trie: Arc = build_tok_trie(tokenizer.clone()).into(); - let is_xlora = match &model { - Model::Llama(_) => false, - Model::XLoraLlama(_) => !is_lora, - }; let num_hidden_layers = match model { Model::Llama(ref model) => model.cache.lock().len(), Model::XLoraLlama(ref model) => model.cache.lock().len(), @@ -372,10 +338,10 @@ impl Loader for GGMLLoader { repeat_last_n: self.config.repeat_last_n, tok_trie, has_no_kv_cache: self.no_kv_cache, - is_xlora, num_hidden_layers, eos_tok: eos, - is_lora, + kind: self.kind.clone(), + is_xlora, }, }))) } @@ -508,14 +474,16 @@ impl Pipeline for GGMLPipeline { } } fn activate_adapters(&mut self, adapter_names: Vec) -> anyhow::Result { - if !self.metadata.is_lora { - anyhow::bail!("Cannot activate adapters non-LoRA models.") + let is_lora = self.metadata.kind.is_adapted_and(|a| a.is_lora()); + if !is_lora { + anyhow::bail!("Activating adapters is only supported for models fine-tuned with LoRA.") } + match self.model { - Model::Llama(_) => unreachable!(), Model::XLoraLlama(ref mut model) => model .activate_adapters(adapter_names) .map_err(anyhow::Error::msg), + _ => unreachable!(), } } } diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 71520b1d6d..abe2b2c480 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -1,7 +1,8 @@ use super::cache_manager::DefaultCacheManager; use super::{ - get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs, - ModelKind, ModelPaths, Pipeline, TokenSource, XLoraPaths, + get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader, + ModelInputs, ModelKind, ModelPaths, Pipeline, PrettyName, QuantizationKind, TokenSource, + XLoraPaths, }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; @@ -12,8 +13,8 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use 
crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; -use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use crate::xlora_models::NonGranularState; use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG}; use crate::{ @@ -134,11 +135,15 @@ impl GGUFLoaderBuilder { quantized_model_id: String, quantized_filename: String, ) -> Self { + let kind = ModelKind::Quantized { + quant: QuantizationKind::Gguf, + }; + Self { config, chat_template, model_id: tok_model_id, - kind: ModelKind::QuantizedGGUF, + kind, quantized_filename, quantized_model_id, ..Default::default() @@ -175,7 +180,8 @@ impl GGUFLoaderBuilder { no_kv_cache: bool, tgt_non_granular_index: Option, ) -> Self { - self.kind = ModelKind::XLoraGGUF; + self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into(); + self.with_adapter( xlora_model_id, xlora_order, @@ -185,7 +191,8 @@ impl GGUFLoaderBuilder { } pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self { - self.kind = ModelKind::LoraGGUF; + self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into(); + self.with_adapter(lora_model_id, lora_order, false, None) } @@ -376,123 +383,43 @@ impl Loader for GGUFLoader { } }; - let mut is_lora = false; + let has_adapter = self.kind.is_adapted(); + let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); + + let model_config = { + // Base config (quantization only): + let quant = ModelConfig::ParamsGGUF((model, &mut file).into(), (device, mapper).into()); + + // With optional adapter config: + let mut adapter = None; + if has_adapter { + adapter.replace(ModelConfig::Adapter::try_new( + paths, device, silent, is_xlora, + )?); + } + + ModelConfig::ModelParams::builder() + .quant(quant) + .and_adapter(adapter) + .build() + }; + + // Config into model: let model = match self.kind { - ModelKind::QuantizedGGUF => match arch { - GGUFArchitecture::Llama => { - Model::Llama(QLlama::from_gguf(model, &mut file, device, mapper)?) - } - GGUFArchitecture::Phi2 => { - Model::Phi2(QPhi::from_gguf(model, &mut file, device, mapper)?) - } - GGUFArchitecture::Phi3 => { - Model::Phi3(QPhi3::from_gguf(model, &mut file, device, mapper)?) - } - a => bail!("Unsupported architecture `{a:?}`"), + ModelKind::Quantized { .. } => match arch { + GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?), + GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?), + GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?), + a => bail!("Unsupported architecture `{a:?}` for GGUF"), + }, + ModelKind::AdapterQuantized { adapter, .. 
} => match arch { + GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?), + GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?), + a => bail!( + "Unsupported architecture `{a:?}` for GGUF {kind}", + kind = adapter.pretty_name() + ), }, - ModelKind::XLoraGGUF => { - let vb = from_mmaped_safetensors( - vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - match arch { - GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - Some(paths.get_classifier_config().as_ref().unwrap().clone()), - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - Some(paths.get_classifier_config().as_ref().unwrap().clone()), - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - a => bail!("Unsupported architecture for GGUF X-LoRA `{a:?}`"), - } - } - ModelKind::LoraGGUF => { - is_lora = true; - let vb = from_mmaped_safetensors( - vec![], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - match arch { - GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - None, - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - None, - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - a => bail!("Unsupported architecture for GGUF LoRA `{a:?}`"), - } - } _ => unreachable!(), }; @@ -509,10 +436,6 @@ impl Loader for GGUFLoader { Model::XLoraPhi3(ref p) => p.max_seq_len, }; let tok_trie: Arc = build_tok_trie(tokenizer.clone()).into(); - let is_xlora = match &model { - Model::Llama(_) | Model::Phi2(_) | Model::Phi3(_) => false, - Model::XLoraLlama(_) | Model::XLoraPhi3(_) => !is_lora, - }; let num_hidden_layers = match model { Model::Llama(ref model) => model.cache.lock().len(), Model::Phi2(ref model) => model.cache.lock().len(), @@ -553,10 +476,10 @@ impl Loader for GGUFLoader { repeat_last_n: self.config.repeat_last_n, tok_trie, has_no_kv_cache: self.no_kv_cache, - is_xlora, num_hidden_layers, eos_tok: eos, - is_lora, + kind: self.kind.clone(), + is_xlora, }, }))) } @@ -685,19 +608,19 @@ impl Pipeline for GGUFPipeline { } } fn activate_adapters(&mut self, adapter_names: Vec) -> anyhow::Result { - if !self.metadata.is_lora { - anyhow::bail!("Cannot activate adapters non-LoRA models.") + let is_lora = self.metadata.kind.is_adapted_and(|a| a.is_lora()); + if !is_lora { + anyhow::bail!("Activating adapters is only supported for models 
fine-tuned with LoRA.") } + match self.model { - Model::Llama(_) => unreachable!(), - Model::Phi2(_) => unreachable!(), - Model::Phi3(_) => unreachable!(), Model::XLoraLlama(ref mut model) => model .activate_adapters(adapter_names) .map_err(anyhow::Error::msg), Model::XLoraPhi3(ref mut model) => model .activate_adapters(adapter_names) .map_err(anyhow::Error::msg), + _ => unreachable!(), } } } diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 5dae166a37..b3c07b843b 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -37,7 +37,7 @@ use rand_isaac::Isaac64Rng; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use serde_json::Value; pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline}; -use std::fmt::{Debug, Display}; +use std::fmt::Debug; use std::path::Path; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -226,31 +226,12 @@ impl fmt::Display for TokenSource { } } -#[derive(Clone, Default)] /// The kind of model to build. +#[derive(Clone, Default, derive_more::From, strum::Display)] pub enum ModelKind { - #[default] - Normal, - XLoraNormal, - XLoraGGUF, - XLoraGGML, - QuantizedGGUF, - QuantizedGGML, - LoraGGUF, - LoraGGML, - LoraNormal, - Speculative { - target: Box, - draft: Box, - }, -} - -// TODO: Future replacement for `ModelKind` above: -#[derive(Default, derive_more::From, strum::Display)] -pub enum ModelKindB { #[default] #[strum(to_string = "normal (no quant, no adapters)")] - Plain, + Normal, #[strum(to_string = "quantized from {quant} (no adapters)")] Quantized { quant: QuantizationKind }, @@ -264,8 +245,6 @@ pub enum ModelKindB { quant: QuantizationKind, }, - // TODO: This would need to be later changed to reference `Self`, but this current way - // avoids having to handle the conversion logic with `ModelKind`. #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")] Speculative { target: Box, @@ -273,21 +252,40 @@ pub enum ModelKindB { }, } -#[derive(Clone, Copy, strum::Display, strum::EnumIs)] +#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] #[strum(serialize_all = "kebab-case")] pub enum QuantizationKind { + /// GGML Ggml, + /// GGUF Gguf, } -#[derive(Clone, Copy, strum::Display, strum::EnumIs)] +#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] #[strum(serialize_all = "kebab-case")] pub enum AdapterKind { + /// LoRA Lora, + /// X-LoRA XLora, } -impl ModelKindB { +// For the proper name as formatted via doc comment for a variant +pub trait PrettyName: strum::EnumMessage + ToString { + fn pretty_name(&self) -> String { + match self.get_documentation() { + Some(s) => s.to_string(), + // Instead of panic via expect(), + // fallback to default kebab-case: + None => self.to_string(), + } + } +} + +impl PrettyName for AdapterKind {} +impl PrettyName for QuantizationKind {} + +impl ModelKind { // Quantized helpers: pub fn is_quantized(&self) -> bool { self.quantized_kind().iter().any(|q| q.is_some()) @@ -298,14 +296,14 @@ impl ModelKindB { } pub fn quantized_kind(&self) -> Vec> { - use ModelKindB::*; + use ModelKind::*; match self { - Plain | Adapter { .. } => vec![None], + Normal | Adapter { .. } => vec![None], Quantized { quant } | AdapterQuantized { quant, .. 
} => vec![Some(*quant)], Speculative { target, draft } => { - let t = ModelKindB::from(*target.clone()); - let d = ModelKindB::from(*draft.clone()); + let t = *target.clone(); + let d = *draft.clone(); [t.quantized_kind(), d.quantized_kind()].concat() } @@ -322,14 +320,14 @@ impl ModelKindB { } pub fn adapted_kind(&self) -> Vec> { - use ModelKindB::*; + use ModelKind::*; match self { - Plain | Quantized { .. } => vec![None], + Normal | Quantized { .. } => vec![None], Adapter { adapter } | AdapterQuantized { adapter, .. } => vec![Some(*adapter)], Speculative { target, draft } => { - let t = ModelKindB::from(*target.clone()); - let d = ModelKindB::from(*draft.clone()); + let t = *target.clone(); + let d = *draft.clone(); [t.adapted_kind(), d.adapted_kind()].concat() } @@ -337,65 +335,6 @@ impl ModelKindB { } } -// TODO: Temporary compatibility layers follow (until a future PR follow-up introduces a breaking change) -impl Display for ModelKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", ModelKindB::from(self.clone())) - } -} - -// Delegate to `ModelKindB` methods: -impl ModelKind { - // Quantized helpers: - pub fn is_quantized(&self) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_quantized() - } - - pub fn is_quantized_and(&self, f: impl FnMut(QuantizationKind) -> bool) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_quantized_and(f) - } - - pub fn quantized_kind(&self) -> Vec> { - let k = ModelKindB::from(self.clone()); - k.quantized_kind() - } - - // Adapter helpers: - pub fn is_adapted(&self) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_adapted() - } - - pub fn is_adapted_and(&self, f: impl FnMut(AdapterKind) -> bool) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_adapted_and(f) - } - - pub fn adapted_kind(&self) -> Vec> { - let k = ModelKindB::from(self.clone()); - k.adapted_kind() - } -} - -impl From for ModelKindB { - fn from(kind: ModelKind) -> Self { - match kind { - ModelKind::Normal => ModelKindB::Plain, - ModelKind::QuantizedGGML => (QuantizationKind::Ggml).into(), - ModelKind::QuantizedGGUF => (QuantizationKind::Gguf).into(), - ModelKind::XLoraNormal => (AdapterKind::XLora).into(), - ModelKind::XLoraGGML => (AdapterKind::XLora, QuantizationKind::Ggml).into(), - ModelKind::XLoraGGUF => (AdapterKind::XLora, QuantizationKind::Gguf).into(), - ModelKind::LoraNormal => (AdapterKind::Lora).into(), - ModelKind::LoraGGML => (AdapterKind::Lora, QuantizationKind::Ggml).into(), - ModelKind::LoraGGUF => (AdapterKind::Lora, QuantizationKind::Gguf).into(), - ModelKind::Speculative { target, draft } => (target, draft).into(), - } - } -} - /// The `Loader` trait abstracts the loading process. The primary entrypoint is the /// `load_model` method. /// @@ -458,10 +397,11 @@ pub struct GeneralMetadata { pub repeat_last_n: usize, pub tok_trie: Arc, pub has_no_kv_cache: bool, - pub is_xlora: bool, pub num_hidden_layers: usize, pub eos_tok: Vec, - pub is_lora: bool, + pub kind: ModelKind, + // TODO: Replace is_xlora queries to check via kind instead: + pub is_xlora: bool, } pub enum AdapterInstruction { @@ -792,7 +732,10 @@ pub trait NormalModel { Ok(()) } fn activate_adapters(&mut self, _: Vec) -> candle_core::Result { - candle_core::bail!("Unable to activate adapters for model without adapters"); + // NOTE: While X-LoRA shares a similar name, it is not equivalent. Its adapter set must remain the same. + candle_core::bail!( + "Activating adapters is only supported for models fine-tuned with LoRA." 
+ ); } } diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index bf36cb9016..c97e194b4e 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -4,8 +4,9 @@ use super::loaders::{ Phi3Loader, Qwen2Loader, }; use super::{ - get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs, - ModelKind, ModelPaths, NormalModel, NormalModelLoader, Pipeline, TokenSource, XLoraPaths, + get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader, + ModelInputs, ModelKind, ModelPaths, NormalModel, NormalModelLoader, Pipeline, TokenSource, + XLoraPaths, }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; @@ -130,7 +131,9 @@ impl NormalLoaderBuilder { no_kv_cache: bool, tgt_non_granular_index: Option, ) -> Self { - self.kind = ModelKind::XLoraNormal; + self.kind = ModelKind::Adapter { + adapter: AdapterKind::XLora, + }; self.with_adapter( xlora_model_id, xlora_order, @@ -140,7 +143,9 @@ impl NormalLoaderBuilder { } pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self { - self.kind = ModelKind::LoraNormal; + self.kind = ModelKind::Adapter { + adapter: AdapterKind::Lora, + }; self.with_adapter(lora_model_id, lora_order, false, None) } @@ -242,7 +247,7 @@ impl Loader for NormalLoader { Device::Cpu }; - let is_lora = self.kind.is_adapted_and(|a| a.is_lora()); + let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); let mut model = match self.kind { ModelKind::Normal => normal_model_loader!( @@ -258,7 +263,9 @@ impl Loader for NormalLoader { in_situ_quant.is_some(), device.clone() ), - ModelKind::XLoraNormal => xlora_model_loader!( + ModelKind::Adapter { + adapter: AdapterKind::XLora, + } => xlora_model_loader!( paths, dtype, default_dtype, @@ -271,7 +278,9 @@ impl Loader for NormalLoader { in_situ_quant.is_some(), device.clone() ), - ModelKind::LoraNormal => lora_model_loader!( + ModelKind::Adapter { + adapter: AdapterKind::Lora, + } => lora_model_loader!( paths, dtype, default_dtype, @@ -284,13 +293,7 @@ impl Loader for NormalLoader { in_situ_quant.is_some(), device.clone() ), - ModelKind::QuantizedGGUF - | ModelKind::QuantizedGGML - | ModelKind::XLoraGGUF - | ModelKind::XLoraGGML - | ModelKind::LoraGGUF - | ModelKind::LoraGGML - | ModelKind::Speculative { .. 
} => unreachable!(),
+            _ => unreachable!(),
         };
 
         let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?;
@@ -305,7 +308,6 @@ impl Loader for NormalLoader {
 
         let max_seq_len = model.max_seq_len();
         let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
-        let is_xlora = model.is_xlora() && !is_lora;
         let num_hidden_layers = model.cache().lock().len();
         let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
         Ok(Arc::new(Mutex::new(NormalPipeline {
@@ -326,10 +328,10 @@ impl Loader for NormalLoader {
                 repeat_last_n: self.config.repeat_last_n,
                 tok_trie,
                 has_no_kv_cache: self.no_kv_cache,
-                is_xlora,
                 num_hidden_layers,
                 eos_tok: eos,
-                is_lora,
+                kind: self.kind.clone(),
+                is_xlora,
             },
         })))
     }
diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs
index 72fb055d4e..c07f63e3f9 100644
--- a/mistralrs-core/src/utils/mod.rs
+++ b/mistralrs-core/src/utils/mod.rs
@@ -1,3 +1,4 @@
+pub(crate) mod model_config;
 pub(crate) mod tokenizer;
 pub(crate) mod tokens;
 pub(crate) mod varbuilder_utils;
diff --git a/mistralrs-core/src/utils/model_config.rs b/mistralrs-core/src/utils/model_config.rs
new file mode 100644
index 0000000000..7315f27b5f
--- /dev/null
+++ b/mistralrs-core/src/utils/model_config.rs
@@ -0,0 +1,323 @@
+use super::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
+use anyhow::Result;
+use candle_core::quantized::{ggml_file, gguf_file};
+use candle_nn::VarBuilder;
+use std::{collections::HashMap, path::PathBuf};
+
+use crate::{
+    lora::{LoraConfig, Ordering},
+    pipeline::ModelPaths,
+    xlora_models::XLoraConfig,
+    DeviceMapMetadata,
+};
+
+#[derive(derive_more::From)]
+pub struct FileGGML {
+    pub ct: ggml_file::Content,
+    pub gqa: usize,
+}
+
+#[derive(derive_more::From)]
+pub struct FileGGUF<'a> {
+    pub ct: gguf_file::Content,
+    pub reader: &'a mut std::fs::File,
+}
+
+#[derive(derive_more::From)]
+pub struct Device<'a> {
+    pub device: &'a candle_core::Device,
+    pub mapper: DeviceMapMetadata,
+}
+
+pub struct Adapter<'a> {
+    pub xlora_config: Option<XLoraConfig>,
+    pub lora_config: &'a [((String, String), LoraConfig)],
+    pub vb: VarBuilder<'a>,
+    pub ordering: &'a Ordering,
+    pub preload_adapters: Option<HashMap<String, (VarBuilder<'a>, LoraConfig)>>,
+}
+
+impl<'a> Adapter<'a> {
+    // NOTE: It is not possible to store references to the values returned by `load_preload_adapters()` and
+    // `from_mmaped_safetensors()`, as they would be dropped at the end of this method. `Adapter` therefore takes
+    // ownership of `vb` + `preload_adapters` and passes them by reference to the `from_gguf()` / `from_ggml()`
+    // methods when proxying to params.
+    // NOTE: Because references persist in the returned struct, additional lifetime annotations were required.
+    #[allow(clippy::borrowed_box)]
+    pub fn try_new<'b: 'a>(
+        paths: &'b Box<dyn ModelPaths>,
+        device: &'b candle_core::Device,
+        silent: bool,
+        is_xlora: bool,
+    ) -> Result<Self> {
+        let lora_config = paths.get_adapter_configs().as_ref().unwrap();
+        let ordering = paths.get_ordering().as_ref().unwrap();
+        let preload_adapters = load_preload_adapters(
+            paths.get_lora_preload_adapter_info(),
+            candle_core::DType::F32,
+            device,
+            silent,
+        )?;
+
+        // X-LoRA support:
+        let mut xlora_paths: Vec<PathBuf> = vec![];
+        let mut xlora_config: Option<XLoraConfig> = None;
+        if is_xlora {
+            xlora_paths = vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()];
+            xlora_config = Some(paths.get_classifier_config().as_ref().unwrap().clone());
+        }
+
+        // Create VarBuilder:
+        // TODO: `from_mmaped_safetensors` has `xlora_paths` as the 2nd param (_valid but params need to be named better_)
+        let vb = from_mmaped_safetensors(
+            xlora_paths,
+            paths
+                .get_adapter_filenames()
+                .as_ref()
+                .unwrap()
+                .iter()
+                .map(|(_, x)| (*x).to_owned())
+                .collect::<Vec<_>>(),
+            candle_core::DType::F32,
+            device,
+            silent,
+        )?;
+
+        Ok(Self {
+            lora_config,
+            xlora_config,
+            vb,
+            ordering,
+            preload_adapters,
+        })
+    }
+}
+
+// New type wrappers that segment the distinct parameter sets used by the `from_ggml()` + `from_gguf()` methods:
+pub struct ParamsGGML(pub FileGGML);
+pub struct ParamsGGUF<'a>(pub FileGGUF<'a>, pub Device<'a>);
+
+// A `None` type vs the `Some` type (`Adapter<'a>`):
+pub struct NoAdapter {}
+
+// Marker traits to restrict type input:
+// (required workaround to support impl on subtypes, otherwise an enum would be used)
+pub trait QuantParams {}
+impl QuantParams for ParamsGGML {}
+impl QuantParams for ParamsGGUF<'_> {}
+
+// Emulates `Option<Adapter>` but is compatible as a type bound in `impl` for the `Some` vs `None` cases:
+pub trait MaybeAdapter {}
+impl MaybeAdapter for Adapter<'_> {}
+impl MaybeAdapter for NoAdapter {}
+
+// `derive_more::From` provides a terser construction for the enum variants of `ModelParams`.
+#[derive(derive_more::From)]
+pub struct Config<Q: QuantParams, A: MaybeAdapter> {
+    pub quant: Q,
+    pub adapter: A,
+}
+
+// NOTE: `variantly` is used for the `.expect_quantized()` / `.expect_adapted()` methods.
+// The `where` clause is required due to a bug with inline bounds:
+// https://github.com/luker-os/variantly/pull/16
+#[allow(clippy::large_enum_variant)]
+#[derive(variantly::Variantly)]
+pub enum ModelParams<'a, Q>
+where
+    Q: QuantParams,
+{
+    Quantized(Config<Q, NoAdapter>),
+    Adapted(Config<Q, Adapter<'a>>),
+}
+
+// A `builder()` method is derived from the `new()` method and its params (derived builder struct fields).
+// NOTE: Intended to be built via the fluent API in a single line; params cannot be appended conditionally.
+// `.adapter(Adapter<'a>)` or, for conditional usage, `.and_adapter(Option<Adapter<'a>>)` can be used.
+// Omitting the `.adapter()` call prior to calling `build()` is also fine; it defaults to `None`.
+#[buildstructor::buildstructor]
+impl<'a, Q: QuantParams> ModelParams<'a, Q> {
+    #[builder]
+    pub fn new<'b: 'a>(quant: Q, adapter: Option<Adapter<'b>>) -> Self {
+        match adapter {
+            None => Self::Quantized((quant, NoAdapter {}).into()),
+            Some(a) => Self::Adapted((quant, a).into()),
+        }
+    }
+}
+
+// Traits for the existing methods used across the various model types to impl `from_ggml()` / `from_gguf()`
+// Basic:
+pub trait FromGGML {
+    fn from_ggml(ct: ggml_file::Content, gqa: usize) -> Result<Self, candle_core::Error>
+    where
+        Self: Sized;
+}
+
+pub trait FromGGUF {
+    fn from_gguf<R: std::io::Seek + std::io::Read>(
+        ct: gguf_file::Content,
+        reader: &mut R,
+        device: &candle_core::Device,
+        mapper: DeviceMapMetadata,
+    ) -> Result<Self, candle_core::Error>
+    where
+        Self: Sized;
+}
+
+// Extended variants:
+pub trait FromAdapterGGML {
+    fn from_ggml(
+        ct: ggml_file::Content,
+        gqa: usize,
+        lora_config: &[((String, String), LoraConfig)],
+        vb: &VarBuilder,
+        ordering: &Ordering,
+        xlora_config: Option<XLoraConfig>,
+        preload_adapters: &Option<HashMap<String, (VarBuilder, LoraConfig)>>,
+    ) -> Result<Self, candle_core::Error>
+    where
+        Self: Sized;
+}
+pub trait FromAdapterGGUF {
+    #[allow(clippy::too_many_arguments)]
+    fn from_gguf<R: std::io::Seek + std::io::Read>(
+        ct: gguf_file::Content,
+        reader: &mut R,
+        device: &candle_core::Device,
+        lora_config: &[((String, String), LoraConfig)],
+        vb: &VarBuilder,
+        ordering: &Ordering,
+        xlora_config: Option<XLoraConfig>,
+        mapper: DeviceMapMetadata,
+        preload_adapters: &Option<HashMap<String, (VarBuilder, LoraConfig)>>,
+    ) -> Result<Self, candle_core::Error>
+    where
+        Self: Sized;
+}
+
+// NOTE: Below is a workaround to proxy params to the existing `from_gguf()` / `from_ggml()` methods of the traits covered above.
+impl<T: FromGGML> Config<ParamsGGML, NoAdapter> {
+    pub fn try_into_model(self) -> Result<T, candle_core::Error> {
+        // Destructure props:
+        let ParamsGGML(FileGGML { ct, gqa }) = self.quant;
+
+        // Forwards all structured fields above into the required flattened param sequence:
+        T::from_ggml(ct, gqa)
+    }
+}
+
+impl<T: FromAdapterGGML> Config<ParamsGGML, Adapter<'_>> {
+    pub fn try_into_model(self) -> Result<T, candle_core::Error> {
+        // Destructure props:
+        let ParamsGGML(FileGGML { ct, gqa }) = self.quant;
+
+        let Adapter {
+            xlora_config,
+            lora_config,
+            vb,
+            ordering,
+            preload_adapters,
+        } = self.adapter;
+
+        // Forwards all structured fields above into the required flattened param sequence:
+        T::from_ggml(
+            ct,
+            gqa,
+            lora_config,
+            &vb,
+            ordering,
+            xlora_config,
+            &preload_adapters,
+        )
+    }
+}
+
+impl<T: FromGGUF> Config<ParamsGGUF<'_>, NoAdapter> {
+    pub fn try_into_model(self) -> Result<T, candle_core::Error> {
+        // Destructure props:
+        let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant;
+
+        // Forwards all structured fields above into the required flattened param sequence:
+        T::from_gguf(ct, reader, device, mapper)
+    }
+}
+
+impl<T: FromAdapterGGUF> Config<ParamsGGUF<'_>, Adapter<'_>> {
+    pub fn try_into_model(self) -> Result<T, candle_core::Error> {
+        // Destructure props:
+        let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant;
+
+        let Adapter {
+            xlora_config,
+            lora_config,
+            vb,
+            ordering,
+            preload_adapters,
+        } = self.adapter;
+
+        // Forwards all structured fields above into the required flattened param sequence:
+        T::from_gguf(
+            ct,
+            reader,
+            device,
+            lora_config,
+            &vb,
+            ordering,
+            xlora_config,
+            mapper,
+            &preload_adapters,
+        )
+    }
+}
+
+use crate::{
+    models::quantized_llama::ModelWeights as QLlama,
+    models::quantized_phi2::ModelWeights as QPhi,
+    models::quantized_phi3::ModelWeights as QPhi3,
+    xlora_models::{XLoraQLlama, XLoraQPhi3},
+};
+use akin::akin;
+
+impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama {
+    type Error = candle_core::Error;
+
+    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
+        let config = params.expect_quantized("`Config` should be GGML Quantized");
+        config.try_into_model()
+    }
+}
+
+impl TryFrom<ModelParams<'_, ParamsGGML>> for XLoraQLlama {
+    type Error = candle_core::Error;
+
+    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
+        let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter");
+        config.try_into_model()
+    }
+}
+
+akin! {
+    let &models_gguf = [QLlama, QPhi, QPhi3];
+
+    impl TryFrom<ModelParams<'_, ParamsGGUF<'_>>> for *models_gguf {
+        type Error = candle_core::Error;
+
+        fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result<Self, Self::Error> {
+            let config = params.expect_quantized("`Config` should be GGUF Quantized");
+            config.try_into_model()
+        }
+    }
+}
+
+akin! {
+    let &models_gguf_a = [XLoraQLlama, XLoraQPhi3];
+
+    impl TryFrom<ModelParams<'_, ParamsGGUF<'_>>> for *models_gguf_a {
+        type Error = candle_core::Error;
+
+        fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result<Self, Self::Error> {
+            let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter");
+            config.try_into_model()
+        }
+    }
+}
diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs
index 48281c9dd2..9dd26c0e91 100644
--- a/mistralrs-core/src/xlora_models/gemma.rs
+++ b/mistralrs-core/src/xlora_models/gemma.rs
@@ -737,7 +737,7 @@ impl NormalModel for XLoraModel {
         &self.device
     }
     fn is_xlora(&self) -> bool {
-        false
+        true
     }
     fn max_seq_len(&self) -> usize {
         self.max_seq_len
diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs
index 2a0b83581d..3b68cbcd81 100644
--- a/mistralrs-core/src/xlora_models/quantized_llama.rs
+++ b/mistralrs-core/src/xlora_models/quantized_llama.rs
@@ -21,6 +21,7 @@ use crate::DeviceMapMetadata;
 
 use super::classifier::XLoraClassifier;
 use super::{verify_sanity_adapters, NonGranularState, ScalingsMaker, XLoraConfig};
+use crate::utils::model_config as ModelConfig;
 
 const MAX_SEQ_LEN: u32 = 4096;
 const SUPPORTED_LAYERS: [&str; 7] = [
@@ -269,8 +270,8 @@ pub struct ModelWeights {
     mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
 }
 
-impl ModelWeights {
-    pub fn from_ggml(
+impl ModelConfig::FromAdapterGGML for ModelWeights {
+    fn from_ggml(
         mut ct: ggml_file::Content,
         gqa: usize,
         lora_config: &[((String, String), LoraConfig)],
@@ -440,9 +441,11 @@ impl ModelWeights {
             mapper: None,
         })
     }
+}
 
+impl ModelConfig::FromAdapterGGUF for ModelWeights {
     #[allow(clippy::too_many_arguments)]
-    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+    fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
         reader: &mut R,
         device: &Device,
@@ -710,7 +713,9 @@ impl ModelWeights {
             mapper: Some(mapper),
         })
     }
+}
 
+impl ModelWeights {
     pub fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
         let mut sum = 0;
         for layer in self.layers.iter_mut() {
diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs
index 210f2275be..248bc4175f 100644
--- a/mistralrs-core/src/xlora_models/quantized_phi3.rs
+++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs
@@ -33,6 +33,7 @@ use super::Cache;
 use super::NonGranularState;
 use super::ScalingsMaker;
 use super::XLoraConfig;
+use crate::utils::model_config as ModelConfig;
 
 const SUPPORTED_LAYERS: [&str; 4] = [
     "self_attn.qkv_proj",
@@ -212,9 +213,9 @@ fn precomput_freqs_cis(
     Ok((cos, sin))
 }
 
-impl ModelWeights {
+impl ModelConfig::FromAdapterGGUF for ModelWeights {
     #[allow(clippy::too_many_arguments)]
-    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+    fn from_gguf<R: std::io::Seek + std::io::Read>(
         ct: gguf_file::Content,
         reader: &mut R,
         device: &Device,
@@ -349,7 +350,9 @@ impl ModelWeights {
             }),
         })
     }
+}
 
+impl ModelWeights {
    pub fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
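
Reviewer note, not part of the diff: below is a minimal sketch of how the refactored `ModelConfig` path is intended to be driven, mirroring the GGUF branch of the loader code above. The function name and local variables (`build_quantized_llama`, `content`, `file`, `device`, `mapper`) are hypothetical stand-ins for values the GGUF loader already has in scope.

```rust
// Sketch only: assumes the same crate-internal items the loaders above import.
use crate::models::quantized_llama::ModelWeights as QLlama;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

fn build_quantized_llama(
    content: candle_core::quantized::gguf_file::Content,
    file: &mut std::fs::File,
    device: &candle_core::Device,
    mapper: DeviceMapMetadata,
) -> candle_core::Result<QLlama> {
    // 1. Wrap the quantization inputs; `derive_more::From` lets the tuples convert directly:
    let quant = ModelConfig::ParamsGGUF((content, file).into(), (device, mapper).into());

    // 2. Build the params. Omitting `.adapter()` / `.and_adapter()` leaves the adapter as `None`:
    let params = ModelConfig::ModelParams::builder().quant(quant).build();

    // 3. Each supported model implements `TryFrom<ModelParams<...>>`, so conversion is uniform:
    QLlama::try_from(params)
}
```

An adapter-equipped load would additionally pass `.and_adapter(Some(ModelConfig::Adapter::try_new(paths, device, silent, is_xlora)?))` on the builder and convert into `XLoraQLlama` / `XLoraQPhi3` instead, exactly as the `gguf.rs` and `ggml.rs` loaders above do.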