From 1d21c5f2d8a75545135741d615fbd7c41106d5d7 Mon Sep 17 00:00:00 2001
From: Brennan Kinney <5098581+polarathene@users.noreply.github.com>
Date: Fri, 31 May 2024 12:27:04 +1200
Subject: [PATCH] refactor: GGUF + GGML Loaders with `ModelKind` (#356)

* chore: Communicate actual difference

These methods are very verbose, but really only differ by the two params that differentiate Lora vs XLora.

* refactor: Introduce `ModelConfig` + `from_gguf` proxy

`ModelConfig` groups the common properties used across the `from_gguf()` methods. This better communicates the differences across impls of `from_gguf()`. The quantized xlora models' `from_gguf()` now has a prop-to-param forwarder as a workaround to minimize breakage elsewhere.

* refactor: Add `from_ggml` proxy

Very similar to the `from_gguf` proxy, except only the `quantized_llama.rs` xlora model supports this. No `Device` params, and slightly different `File` params than the GGUF type.

* chore: DRY `ggml.rs` + `gguf.rs` common adapter config logic

Finally, all of this extra boilerplate can be shifted into the model config `Adapter` struct, self-contained in a single method. This required adjusting ownership a little to satisfy the compiler. The original `from_gguf()` and `from_ggml()` methods are unaffected; they still receive the expected params by reference.

* refactor(breaking): Leverage traits for `from_gguf()` / `from_ggml()`

This introduces a slight breaking change, in that using these `from_gguf()` / `from_ggml()` methods now requires importing the trait into scope. The methods drop the `pub` prefix as they inherit `pub` from the trait requirement itself. The purpose of this trait is to avoid each model duplicating the structs-to-params mapping helper method; instead it can be centralized.

* chore: DRY - Dedupe prop mapping methods

These no longer need to be maintained as copies within the supported model modules. They now leverage the common shared traits and take an annotated type parameter to handle the mapping. The syntax for usage is presently a little more verbose than desired.

* chore: Contextual commit - Alternative prop mapping approaches

For reference, these alternatives could be considered.

* refactor: Add equivalent support for quant models without adapters

- Impl traits for the non-adapter quant models
- Since adapter variants only append parameters and that is now a distinct struct, `model_config` can be defined earlier and a helper `with_adapter()` can convert to the adapter type variant.

* chore: Fix typo

* refactor: Collapse Lora + XLora arms

With a rebase to adopt new methods for `ModelKind`, the shared logic can be hoisted out of the match arms. XLora-specific variables were moved into `Adapter::try_new()` (`model_config.rs`) as they can share the same `paths` parameter by adding a separate bool to toggle x-lora usage.

By hoisting the `model_config` variable out of the match arm, the type would change when calling `with_adapter()`. To prevent that, the separate `Adapter*` tuple structs have been dropped in favor of `ModelParams`, which uses a generic `Q` for the quantization type (trait marker) and a separate optional adapter that can be updated.

`MapParamsToModel` is also no longer compatible as an approach, since the trait bound is ambiguous: there is no distinct adapter type (e.g. `for AdapterGGUF`) to impl upon, and unique method names would become required to avoid conflicts on the same type.

- `try_into_model()` + `try_into_model_with_adapter()` for the separate `Q` types (GGUF/GGML) with/without adapter (see the sketch below).
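For illustration, a condensed sketch of the param mapping this converges on (stand-in field types, and `anyhow::Result` assumed to be available; the real definitions, with lifetimes and the GGUF / adapter variants, are in the `utils/model_config.rs` diff further below):

```rust
use anyhow::Result;

// Stand-in param types for the sketch (the real ones wrap
// `ggml_file::Content`, readers, devices, etc.):
pub struct FileGGML { pub ct: Vec<u8>, pub gqa: usize }
pub struct ParamsGGML(pub FileGGML);
pub struct NoAdapter {}

// Marker traits restrict which types may fill each generic slot:
pub trait QuantParams {}
impl QuantParams for ParamsGGML {}

pub trait MaybeAdapter {}
impl MaybeAdapter for NoAdapter {}

// Common bundle: quantization params plus an (optional) adapter slot.
pub struct Config<Q: QuantParams, A: MaybeAdapter> {
    pub quant: Q,
    pub adapter: A,
}

// Models expose their existing constructor through a shared trait...
pub trait FromGGML: Sized {
    fn from_ggml(ct: Vec<u8>, gqa: usize) -> Result<Self>;
}

// ...so one `try_into_model()` per `Q` / adapter combination destructures
// the config and forwards the flattened param list, instead of every
// model module repeating that mapping.
impl Config<ParamsGGML, NoAdapter> {
    pub fn try_into_model<T: FromGGML>(self) -> Result<T> {
        let ParamsGGML(FileGGML { ct, gqa }) = self.quant;
        T::from_ggml(ct, gqa)
    }
}
```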
- Due to the new struct introduced, there is a slight change to destructuring. The `impl` also bundles both methods now for each `Q` variant. Order was adjusted so the basic method is followed by the adapter method for each `Q` variant, instead of both basic methods followed by both adapter variations.
- Likewise the `ggml.rs` and `gguf.rs` methods without the `MapParamsToModel` trait now rely on a `TryFrom` trait impl.

* refactor: Wrap `ModelParams` into enum for distinct adapter type

This approach introduces another generic parameter `MaybeAdapter` to emulate an `Option` that can be used as a type to `impl` upon. To continue the unified type usage with an adapter variant in the `ggml.rs` / `gguf.rs` pipelines, this must now leverage an enum for the two variants.

- Slightly more complexity is added as a result.
- Adapter `try_into_model()` methods no longer need to check for `Some(Adapter)` to unwrap, since that should always be the case. This is now guaranteed.
- However, similar logic has bubbled up to the `TryFrom` for all impls due to the enum wrapper, so this approach may not be much better beyond broader consistency. Likewise with the `with_adapter()` method.

To minimize boilerplate when unwrapping the enum in the `TryFrom` methods, `Variantly` has been introduced for its `expect_variant()` method. As all four types are distinct, the `_with_adapter()` method can also be named `try_into_model()` thanks to separate impls for the new generic param `MaybeAdapter`.

* chore: Minor improvements

Since the type constraint for `try_into_model()` methods is bound as the return type, it can be inferred without any hint in the `TryFrom`; no need to annotate with `Self`. Use `derive_more` for terser construction of the `Config` struct for `ModelParams` variants.

* refactor: Use `buildstructor` for builder API

This is an alternative approach to building the config. Construction of the config from the param inputs is now handled at the end, not dependent upon separate `new()` + optional `with_adapter()` calls on a mutable variable. Unfortunately the `buildstructor` and `typed-builder` APIs don't allow for easy flexibility of builder methods in different scopes (_due to moves_). `derive-builder` can do this, but not with the more complex types due to the lack of a `Copy` / `Clone`. Thus the `None` option is required as input regardless of whether an adapter is needed.

* chore: Wrap `model_config` assignment into expression

This better communicates that the block is only relevant to assigning this value, while the two `is_lora` / `is_xlora` variables are hoisted above due to their later usage as metadata inputs.

* fix: Drop `is_lora` from `GeneralMetadata`

`pipeline/gguf.rs` + `pipeline/ggml.rs` now ensure that `activate_adapters()` works for X-LoRA too. This is assumed to be a bugfix given the `XLoraLlama` model the two adapter kinds share, along with code everywhere else checking `is_xlora`; no other usage of `is_lora` appears to exist.

- To avoid further ambiguity, the condition is better communicated as `has_adapter`.
- It is unclear if all usage of `is_xlora` is specific to X-LoRA or also intended to be applicable to LoRA, since `XLora*` models do impl `is_xlora() -> true` (except Gemma, which is a potential bug).

`pipeline/normal.rs` handled its own `is_xlora` bool differently than the `gguf.rs` / `ggml.rs` loaders.

- It relied upon `model.is_xlora() && !is_lora`, but we already assume X-LoRA via prior matching on `ModelKind`, which now provides this information via its own `is_x_lora()` method (see the sketch below).
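For reference, a minimal sketch of the `ModelKind` queries the pipelines now lean on (stand-in enums; the real `ModelKind` / `AdapterKind` in `pipeline/mod.rs` below derive the `is_lora()` / `is_x_lora()` helpers via `strum` rather than the `PartialEq` comparison used here):

```rust
// Stand-in kinds, reduced to what the example needs:
#[derive(Clone, Copy, PartialEq)]
enum AdapterKind { Lora, XLora }

enum ModelKind {
    Quantized,
    AdapterQuantized { adapter: AdapterKind },
}

impl ModelKind {
    // Is any adapter (LoRA or X-LoRA) present?
    fn is_adapted(&self) -> bool {
        matches!(self, ModelKind::AdapterQuantized { .. })
    }

    // Is an adapter present, and does it satisfy the predicate?
    fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
        match self {
            ModelKind::AdapterQuantized { adapter } => f(*adapter),
            _ => false,
        }
    }
}

fn main() {
    let kind = ModelKind::AdapterQuantized { adapter: AdapterKind::XLora };

    // What the quantized pipelines now compute up front, replacing the
    // per-match-arm `is_lora` bookkeeping:
    let has_adapter = kind.is_adapted();
    let is_xlora = kind.is_adapted_and(|a| a == AdapterKind::XLora);
    let is_lora = kind.is_adapted_and(|a| a == AdapterKind::Lora);

    assert!(has_adapter && is_xlora && !is_lora);
}
```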
- Only `xlora_models/gemma.rs` would behave differently with this change, but Gemma might have meant to return `true`? * chore: Match on adapter Matches are only for `Quantized` or `AdapterQuantized` variants with no difference in handling by `AdapterKind` variant used. Additionally restores the `GGUF X-LoRA` bail formatted string. For consistency the non-adapter branch also appends `for GGUF` and the architecture in the lora branch now comes before the `ModelKind`. * chore: Support params via tuple `into()` + add note of possible bug * breaking: Replace `ModelKind` with new version A better approach for the most part at encoding the kind info. * lint(clippy): Appease the lint gods `model_config.rs` GGUF and GGML structs prefixed with `Params`. Two exceptions added as the concerns don't seem to warrant change: - `#[allow(clippy::borrowed_box)]` - `#[allow(clippy::large_enum_variant)]` * chore: Convert from `CRLF` to `LF` This file has no other change in the commit beyond line ending conversion. It was mistakenly using CRLF since creation. * lint(rustfmt): Appease the lint gods * fix: Restore `is_lora` condition `GeneralMetadata` now stores the `ModelKind` for this type of information. `activate_adapters()` error message revised. `mod.rs` version includes contextual comment about X-LoRA not being equivalent. * fix: Gemma X-Lora model `is_xlora()` should return `true` Most likely redundant with `GeneralMetadata` now having `ModelKind` to query, but fixing here until all queries replaced. Additionally updates `model_config.rs` note to clarify not a bug. --- mistralrs-core/Cargo.toml | 3 + mistralrs-core/src/models/quantized_llama.rs | 11 +- mistralrs-core/src/models/quantized_phi2.rs | 7 +- mistralrs-core/src/models/quantized_phi3.rs | 7 +- mistralrs-core/src/pipeline/ggml.rs | 124 +++---- mistralrs-core/src/pipeline/gguf.rs | 187 +++------- mistralrs-core/src/pipeline/mod.rs | 137 +++----- mistralrs-core/src/pipeline/normal.rs | 36 +- mistralrs-core/src/utils/mod.rs | 1 + mistralrs-core/src/utils/model_config.rs | 323 ++++++++++++++++++ mistralrs-core/src/xlora_models/gemma.rs | 2 +- .../src/xlora_models/quantized_llama.rs | 11 +- .../src/xlora_models/quantized_phi3.rs | 7 +- 13 files changed, 519 insertions(+), 337 deletions(-) create mode 100644 mistralrs-core/src/utils/model_config.rs diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 9d4c9c199..9d967e15c 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -55,6 +55,9 @@ once_cell.workspace = true toml = "0.8.12" strum = { version = "0.26", features = ["derive"] } derive_more = { version = "0.99.17", default-features = false, features = ["from"] } +akin = "0.4.0" +variantly = "0.4.0" +buildstructor = "0.5.4" tracing-subscriber.workspace = true reqwest = { version = "0.12.4", features = ["blocking"] } diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index e902b3cd1..70b36e38b 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -10,6 +10,7 @@ use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention, }; use crate::pipeline::{extract_logits, Cache}; +use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; const MAX_SEQ_LEN: u32 = 4096; @@ -194,8 +195,8 @@ pub struct ModelWeights { mapper: Option>, } -impl ModelWeights { - pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { +impl 
ModelConfig::FromGGML for ModelWeights { + fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let rotary = RotaryEmbedding::new_partial( 10000., @@ -254,8 +255,10 @@ impl ModelWeights { mapper: None, }) } +} - pub fn from_gguf( +impl ModelConfig::FromGGUF for ModelWeights { + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -383,7 +386,9 @@ impl ModelWeights { mapper: Some(mapper), }) } +} +impl ModelWeights { pub fn forward( &mut self, x: &Tensor, diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index a002f0b89..0c4d3225e 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -9,6 +9,7 @@ use crate::device_map::DeviceMapper; use crate::layers::ScaledDotProductAttention; use crate::layers::{repeat_kv, CausalMasker, QLinear}; use crate::pipeline::{extract_logits, Cache}; +use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; pub const MAX_SEQ_LEN: usize = 4096; @@ -141,8 +142,8 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result { Ok(ln) } -impl ModelWeights { - pub fn from_gguf( +impl ModelConfig::FromGGUF for ModelWeights { + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -211,7 +212,9 @@ impl ModelWeights { mapper, }) } +} +impl ModelWeights { pub fn forward( &mut self, input_ids: &Tensor, diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index 909dd90c2..8149eea84 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -5,6 +5,7 @@ use crate::layers::{ repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention, }; use crate::pipeline::Cache; +use crate::utils::model_config as ModelConfig; use crate::DeviceMapMetadata; use candle_core::quantized::gguf_file; use candle_core::quantized::QMatMul; @@ -159,8 +160,8 @@ fn precomput_freqs_cis( Ok((cos, sin)) } -impl ModelWeights { - pub fn from_gguf( +impl ModelConfig::FromGGUF for ModelWeights { + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -248,7 +249,9 @@ impl ModelWeights { max_seq_len: context_window, }) } +} +impl ModelWeights { pub fn forward(&mut self, input_ids: &Tensor, seqlen_offsets: &[usize]) -> Result { let (_b_sz, seq_len) = input_ids.dims2()?; let mut xs = self.tok_embeddings.forward(input_ids)?; diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 9114f70c9..aaeaea473 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -1,7 +1,7 @@ use super::cache_manager::DefaultCacheManager; use super::{ - get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs, - ModelKind, ModelPaths, Pipeline, TokenSource, XLoraPaths, + get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader, + ModelInputs, ModelKind, ModelPaths, Pipeline, QuantizationKind, TokenSource, XLoraPaths, }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; @@ -11,8 +11,8 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; -use 
crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use crate::xlora_models::NonGranularState; use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG}; use crate::{ @@ -96,12 +96,16 @@ impl GGMLLoaderBuilder { quantized_model_id: String, quantized_filename: String, ) -> Self { + let kind = ModelKind::Quantized { + quant: QuantizationKind::Ggml, + }; + Self { config, chat_template, tokenizer_json, model_id, - kind: ModelKind::QuantizedGGML, + kind, quantized_filename, quantized_model_id, ..Default::default() @@ -138,7 +142,8 @@ impl GGMLLoaderBuilder { no_kv_cache: bool, tgt_non_granular_index: Option, ) -> Self { - self.kind = ModelKind::XLoraGGML; + self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into(); + self.with_adapter( xlora_model_id, xlora_order, @@ -148,7 +153,8 @@ impl GGMLLoaderBuilder { } pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self { - self.kind = ModelKind::LoraGGML; + self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into(); + self.with_adapter(lora_model_id, lora_order, false, None) } @@ -236,7 +242,7 @@ impl Loader for GGMLLoader { if in_situ_quant.is_some() { anyhow::bail!( - "You are trying to in-situ quantize a GGUF model. This will not do anything." + "You are trying to in-situ quantize a GGML model. This will not do anything." ); } if !mapper.is_dummy() { @@ -267,69 +273,33 @@ impl Loader for GGMLLoader { info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`."); } - let mut is_lora = false; - let model = match self.kind { - ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml(model, self.config.gqa)?), - ModelKind::XLoraGGML => { - let vb = from_mmaped_safetensors( - vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - Model::XLoraLlama(XLoraQLlama::from_ggml( - model, - self.config.gqa, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - Some(paths.get_classifier_config().as_ref().unwrap().clone()), - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?) + let has_adapter = self.kind.is_adapted(); + let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); + + let model_config = { + // Base config (quantization only): + let quant = ModelConfig::ParamsGGML((model, self.config.gqa).into()); + + // With optional adapter config: + let mut adapter = None; + if has_adapter { + adapter.replace(ModelConfig::Adapter::try_new( + paths, device, silent, is_xlora, + )?); } - ModelKind::LoraGGML => { - is_lora = true; - let vb = from_mmaped_safetensors( - vec![], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - Model::XLoraLlama(XLoraQLlama::from_ggml( - model, - self.config.gqa, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - None, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?) 
+ + ModelConfig::ModelParams::builder() + .quant(quant) + .and_adapter(adapter) + .build() + }; + + // Config into model: + // NOTE: No architecture to infer like GGUF, Llama model is implicitly matched + let model = match self.kind { + ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?), + ModelKind::AdapterQuantized { .. } => { + Model::XLoraLlama(XLoraQLlama::try_from(model_config)?) } _ => unreachable!(), }; @@ -345,10 +315,6 @@ impl Loader for GGMLLoader { Model::XLoraLlama(ref xl) => xl.max_seq_len, }; let tok_trie: Arc = build_tok_trie(tokenizer.clone()).into(); - let is_xlora = match &model { - Model::Llama(_) => false, - Model::XLoraLlama(_) => !is_lora, - }; let num_hidden_layers = match model { Model::Llama(ref model) => model.cache.lock().len(), Model::XLoraLlama(ref model) => model.cache.lock().len(), @@ -372,10 +338,10 @@ impl Loader for GGMLLoader { repeat_last_n: self.config.repeat_last_n, tok_trie, has_no_kv_cache: self.no_kv_cache, - is_xlora, num_hidden_layers, eos_tok: eos, - is_lora, + kind: self.kind.clone(), + is_xlora, }, }))) } @@ -508,14 +474,16 @@ impl Pipeline for GGMLPipeline { } } fn activate_adapters(&mut self, adapter_names: Vec) -> anyhow::Result { - if !self.metadata.is_lora { - anyhow::bail!("Cannot activate adapters non-LoRA models.") + let is_lora = self.metadata.kind.is_adapted_and(|a| a.is_lora()); + if !is_lora { + anyhow::bail!("Activating adapters is only supported for models fine-tuned with LoRA.") } + match self.model { - Model::Llama(_) => unreachable!(), Model::XLoraLlama(ref mut model) => model .activate_adapters(adapter_names) .map_err(anyhow::Error::msg), + _ => unreachable!(), } } } diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 71520b1d6..abe2b2c48 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -1,7 +1,8 @@ use super::cache_manager::DefaultCacheManager; use super::{ - get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs, - ModelKind, ModelPaths, Pipeline, TokenSource, XLoraPaths, + get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader, + ModelInputs, ModelKind, ModelPaths, Pipeline, PrettyName, QuantizationKind, TokenSource, + XLoraPaths, }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; @@ -12,8 +13,8 @@ use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; -use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use crate::xlora_models::NonGranularState; use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG}; use crate::{ @@ -134,11 +135,15 @@ impl GGUFLoaderBuilder { quantized_model_id: String, quantized_filename: String, ) -> Self { + let kind = ModelKind::Quantized { + quant: QuantizationKind::Gguf, + }; + Self { config, chat_template, model_id: tok_model_id, - kind: ModelKind::QuantizedGGUF, + kind, quantized_filename, quantized_model_id, ..Default::default() @@ -175,7 +180,8 @@ impl GGUFLoaderBuilder { no_kv_cache: bool, tgt_non_granular_index: Option, ) -> Self { - self.kind = ModelKind::XLoraGGUF; + self.kind = (AdapterKind::XLora, QuantizationKind::Gguf).into(); + self.with_adapter( xlora_model_id, xlora_order, @@ -185,7 +191,8 @@ impl 
GGUFLoaderBuilder { } pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self { - self.kind = ModelKind::LoraGGUF; + self.kind = (AdapterKind::Lora, QuantizationKind::Gguf).into(); + self.with_adapter(lora_model_id, lora_order, false, None) } @@ -376,123 +383,43 @@ impl Loader for GGUFLoader { } }; - let mut is_lora = false; + let has_adapter = self.kind.is_adapted(); + let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); + + let model_config = { + // Base config (quantization only): + let quant = ModelConfig::ParamsGGUF((model, &mut file).into(), (device, mapper).into()); + + // With optional adapter config: + let mut adapter = None; + if has_adapter { + adapter.replace(ModelConfig::Adapter::try_new( + paths, device, silent, is_xlora, + )?); + } + + ModelConfig::ModelParams::builder() + .quant(quant) + .and_adapter(adapter) + .build() + }; + + // Config into model: let model = match self.kind { - ModelKind::QuantizedGGUF => match arch { - GGUFArchitecture::Llama => { - Model::Llama(QLlama::from_gguf(model, &mut file, device, mapper)?) - } - GGUFArchitecture::Phi2 => { - Model::Phi2(QPhi::from_gguf(model, &mut file, device, mapper)?) - } - GGUFArchitecture::Phi3 => { - Model::Phi3(QPhi3::from_gguf(model, &mut file, device, mapper)?) - } - a => bail!("Unsupported architecture `{a:?}`"), + ModelKind::Quantized { .. } => match arch { + GGUFArchitecture::Llama => Model::Llama(QLlama::try_from(model_config)?), + GGUFArchitecture::Phi2 => Model::Phi2(QPhi::try_from(model_config)?), + GGUFArchitecture::Phi3 => Model::Phi3(QPhi3::try_from(model_config)?), + a => bail!("Unsupported architecture `{a:?}` for GGUF"), + }, + ModelKind::AdapterQuantized { adapter, .. } => match arch { + GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::try_from(model_config)?), + GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::try_from(model_config)?), + a => bail!( + "Unsupported architecture `{a:?}` for GGUF {kind}", + kind = adapter.pretty_name() + ), }, - ModelKind::XLoraGGUF => { - let vb = from_mmaped_safetensors( - vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - match arch { - GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - Some(paths.get_classifier_config().as_ref().unwrap().clone()), - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - Some(paths.get_classifier_config().as_ref().unwrap().clone()), - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - a => bail!("Unsupported architecture for GGUF X-LoRA `{a:?}`"), - } - } - ModelKind::LoraGGUF => { - is_lora = true; - let vb = from_mmaped_safetensors( - vec![], - paths - .get_adapter_filenames() - .as_ref() - .unwrap() - .iter() - .map(|(_, x)| (*x).to_owned()) - .collect::>(), - DType::F32, - device, - silent, - )?; - - match arch { - GGUFArchitecture::Llama => Model::XLoraLlama(XLoraQLlama::from_gguf( - model, - &mut file, - device, - 
paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - None, - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - GGUFArchitecture::Phi3 => Model::XLoraPhi3(XLoraQPhi3::from_gguf( - model, - &mut file, - device, - paths.get_adapter_configs().as_ref().unwrap(), - &vb, - paths.get_ordering().as_ref().unwrap(), - None, - mapper, - &load_preload_adapters( - paths.get_lora_preload_adapter_info(), - DType::F32, - device, - silent, - )?, - )?), - a => bail!("Unsupported architecture for GGUF LoRA `{a:?}`"), - } - } _ => unreachable!(), }; @@ -509,10 +436,6 @@ impl Loader for GGUFLoader { Model::XLoraPhi3(ref p) => p.max_seq_len, }; let tok_trie: Arc = build_tok_trie(tokenizer.clone()).into(); - let is_xlora = match &model { - Model::Llama(_) | Model::Phi2(_) | Model::Phi3(_) => false, - Model::XLoraLlama(_) | Model::XLoraPhi3(_) => !is_lora, - }; let num_hidden_layers = match model { Model::Llama(ref model) => model.cache.lock().len(), Model::Phi2(ref model) => model.cache.lock().len(), @@ -553,10 +476,10 @@ impl Loader for GGUFLoader { repeat_last_n: self.config.repeat_last_n, tok_trie, has_no_kv_cache: self.no_kv_cache, - is_xlora, num_hidden_layers, eos_tok: eos, - is_lora, + kind: self.kind.clone(), + is_xlora, }, }))) } @@ -685,19 +608,19 @@ impl Pipeline for GGUFPipeline { } } fn activate_adapters(&mut self, adapter_names: Vec) -> anyhow::Result { - if !self.metadata.is_lora { - anyhow::bail!("Cannot activate adapters non-LoRA models.") + let is_lora = self.metadata.kind.is_adapted_and(|a| a.is_lora()); + if !is_lora { + anyhow::bail!("Activating adapters is only supported for models fine-tuned with LoRA.") } + match self.model { - Model::Llama(_) => unreachable!(), - Model::Phi2(_) => unreachable!(), - Model::Phi3(_) => unreachable!(), Model::XLoraLlama(ref mut model) => model .activate_adapters(adapter_names) .map_err(anyhow::Error::msg), Model::XLoraPhi3(ref mut model) => model .activate_adapters(adapter_names) .map_err(anyhow::Error::msg), + _ => unreachable!(), } } } diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 5dae166a3..b3c07b843 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -37,7 +37,7 @@ use rand_isaac::Isaac64Rng; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; use serde_json::Value; pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline}; -use std::fmt::{Debug, Display}; +use std::fmt::Debug; use std::path::Path; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -226,31 +226,12 @@ impl fmt::Display for TokenSource { } } -#[derive(Clone, Default)] /// The kind of model to build. 
+#[derive(Clone, Default, derive_more::From, strum::Display)] pub enum ModelKind { - #[default] - Normal, - XLoraNormal, - XLoraGGUF, - XLoraGGML, - QuantizedGGUF, - QuantizedGGML, - LoraGGUF, - LoraGGML, - LoraNormal, - Speculative { - target: Box, - draft: Box, - }, -} - -// TODO: Future replacement for `ModelKind` above: -#[derive(Default, derive_more::From, strum::Display)] -pub enum ModelKindB { #[default] #[strum(to_string = "normal (no quant, no adapters)")] - Plain, + Normal, #[strum(to_string = "quantized from {quant} (no adapters)")] Quantized { quant: QuantizationKind }, @@ -264,8 +245,6 @@ pub enum ModelKindB { quant: QuantizationKind, }, - // TODO: This would need to be later changed to reference `Self`, but this current way - // avoids having to handle the conversion logic with `ModelKind`. #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")] Speculative { target: Box, @@ -273,21 +252,40 @@ pub enum ModelKindB { }, } -#[derive(Clone, Copy, strum::Display, strum::EnumIs)] +#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] #[strum(serialize_all = "kebab-case")] pub enum QuantizationKind { + /// GGML Ggml, + /// GGUF Gguf, } -#[derive(Clone, Copy, strum::Display, strum::EnumIs)] +#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)] #[strum(serialize_all = "kebab-case")] pub enum AdapterKind { + /// LoRA Lora, + /// X-LoRA XLora, } -impl ModelKindB { +// For the proper name as formatted via doc comment for a variant +pub trait PrettyName: strum::EnumMessage + ToString { + fn pretty_name(&self) -> String { + match self.get_documentation() { + Some(s) => s.to_string(), + // Instead of panic via expect(), + // fallback to default kebab-case: + None => self.to_string(), + } + } +} + +impl PrettyName for AdapterKind {} +impl PrettyName for QuantizationKind {} + +impl ModelKind { // Quantized helpers: pub fn is_quantized(&self) -> bool { self.quantized_kind().iter().any(|q| q.is_some()) @@ -298,14 +296,14 @@ impl ModelKindB { } pub fn quantized_kind(&self) -> Vec> { - use ModelKindB::*; + use ModelKind::*; match self { - Plain | Adapter { .. } => vec![None], + Normal | Adapter { .. } => vec![None], Quantized { quant } | AdapterQuantized { quant, .. } => vec![Some(*quant)], Speculative { target, draft } => { - let t = ModelKindB::from(*target.clone()); - let d = ModelKindB::from(*draft.clone()); + let t = *target.clone(); + let d = *draft.clone(); [t.quantized_kind(), d.quantized_kind()].concat() } @@ -322,14 +320,14 @@ impl ModelKindB { } pub fn adapted_kind(&self) -> Vec> { - use ModelKindB::*; + use ModelKind::*; match self { - Plain | Quantized { .. } => vec![None], + Normal | Quantized { .. } => vec![None], Adapter { adapter } | AdapterQuantized { adapter, .. 
} => vec![Some(*adapter)], Speculative { target, draft } => { - let t = ModelKindB::from(*target.clone()); - let d = ModelKindB::from(*draft.clone()); + let t = *target.clone(); + let d = *draft.clone(); [t.adapted_kind(), d.adapted_kind()].concat() } @@ -337,65 +335,6 @@ impl ModelKindB { } } -// TODO: Temporary compatibility layers follow (until a future PR follow-up introduces a breaking change) -impl Display for ModelKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", ModelKindB::from(self.clone())) - } -} - -// Delegate to `ModelKindB` methods: -impl ModelKind { - // Quantized helpers: - pub fn is_quantized(&self) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_quantized() - } - - pub fn is_quantized_and(&self, f: impl FnMut(QuantizationKind) -> bool) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_quantized_and(f) - } - - pub fn quantized_kind(&self) -> Vec> { - let k = ModelKindB::from(self.clone()); - k.quantized_kind() - } - - // Adapter helpers: - pub fn is_adapted(&self) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_adapted() - } - - pub fn is_adapted_and(&self, f: impl FnMut(AdapterKind) -> bool) -> bool { - let k = ModelKindB::from(self.clone()); - k.is_adapted_and(f) - } - - pub fn adapted_kind(&self) -> Vec> { - let k = ModelKindB::from(self.clone()); - k.adapted_kind() - } -} - -impl From for ModelKindB { - fn from(kind: ModelKind) -> Self { - match kind { - ModelKind::Normal => ModelKindB::Plain, - ModelKind::QuantizedGGML => (QuantizationKind::Ggml).into(), - ModelKind::QuantizedGGUF => (QuantizationKind::Gguf).into(), - ModelKind::XLoraNormal => (AdapterKind::XLora).into(), - ModelKind::XLoraGGML => (AdapterKind::XLora, QuantizationKind::Ggml).into(), - ModelKind::XLoraGGUF => (AdapterKind::XLora, QuantizationKind::Gguf).into(), - ModelKind::LoraNormal => (AdapterKind::Lora).into(), - ModelKind::LoraGGML => (AdapterKind::Lora, QuantizationKind::Ggml).into(), - ModelKind::LoraGGUF => (AdapterKind::Lora, QuantizationKind::Gguf).into(), - ModelKind::Speculative { target, draft } => (target, draft).into(), - } - } -} - /// The `Loader` trait abstracts the loading process. The primary entrypoint is the /// `load_model` method. /// @@ -458,10 +397,11 @@ pub struct GeneralMetadata { pub repeat_last_n: usize, pub tok_trie: Arc, pub has_no_kv_cache: bool, - pub is_xlora: bool, pub num_hidden_layers: usize, pub eos_tok: Vec, - pub is_lora: bool, + pub kind: ModelKind, + // TODO: Replace is_xlora queries to check via kind instead: + pub is_xlora: bool, } pub enum AdapterInstruction { @@ -792,7 +732,10 @@ pub trait NormalModel { Ok(()) } fn activate_adapters(&mut self, _: Vec) -> candle_core::Result { - candle_core::bail!("Unable to activate adapters for model without adapters"); + // NOTE: While X-LoRA shares a similar name, it is not equivalent. Its adapter set must remain the same. + candle_core::bail!( + "Activating adapters is only supported for models fine-tuned with LoRA." 
+ ); } } diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index bf36cb901..c97e194b4 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -4,8 +4,9 @@ use super::loaders::{ Phi3Loader, Qwen2Loader, }; use super::{ - get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs, - ModelKind, ModelPaths, NormalModel, NormalModelLoader, Pipeline, TokenSource, XLoraPaths, + get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader, + ModelInputs, ModelKind, ModelPaths, NormalModel, NormalModelLoader, Pipeline, TokenSource, + XLoraPaths, }; use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; @@ -130,7 +131,9 @@ impl NormalLoaderBuilder { no_kv_cache: bool, tgt_non_granular_index: Option, ) -> Self { - self.kind = ModelKind::XLoraNormal; + self.kind = ModelKind::Adapter { + adapter: AdapterKind::XLora, + }; self.with_adapter( xlora_model_id, xlora_order, @@ -140,7 +143,9 @@ impl NormalLoaderBuilder { } pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self { - self.kind = ModelKind::LoraNormal; + self.kind = ModelKind::Adapter { + adapter: AdapterKind::Lora, + }; self.with_adapter(lora_model_id, lora_order, false, None) } @@ -242,7 +247,7 @@ impl Loader for NormalLoader { Device::Cpu }; - let is_lora = self.kind.is_adapted_and(|a| a.is_lora()); + let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora()); let mut model = match self.kind { ModelKind::Normal => normal_model_loader!( @@ -258,7 +263,9 @@ impl Loader for NormalLoader { in_situ_quant.is_some(), device.clone() ), - ModelKind::XLoraNormal => xlora_model_loader!( + ModelKind::Adapter { + adapter: AdapterKind::XLora, + } => xlora_model_loader!( paths, dtype, default_dtype, @@ -271,7 +278,9 @@ impl Loader for NormalLoader { in_situ_quant.is_some(), device.clone() ), - ModelKind::LoraNormal => lora_model_loader!( + ModelKind::Adapter { + adapter: AdapterKind::Lora, + } => lora_model_loader!( paths, dtype, default_dtype, @@ -284,13 +293,7 @@ impl Loader for NormalLoader { in_situ_quant.is_some(), device.clone() ), - ModelKind::QuantizedGGUF - | ModelKind::QuantizedGGML - | ModelKind::XLoraGGUF - | ModelKind::XLoraGGML - | ModelKind::LoraGGUF - | ModelKind::LoraGGML - | ModelKind::Speculative { .. 
} => unreachable!(), + _ => unreachable!(), }; let tokenizer = get_tokenizer(paths.get_tokenizer_filename())?; @@ -305,7 +308,6 @@ impl Loader for NormalLoader { let max_seq_len = model.max_seq_len(); let tok_trie: Arc = build_tok_trie(tokenizer.clone()).into(); - let is_xlora = model.is_xlora() && !is_lora; let num_hidden_layers = model.cache().lock().len(); let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer); Ok(Arc::new(Mutex::new(NormalPipeline { @@ -326,10 +328,10 @@ impl Loader for NormalLoader { repeat_last_n: self.config.repeat_last_n, tok_trie, has_no_kv_cache: self.no_kv_cache, - is_xlora, num_hidden_layers, eos_tok: eos, - is_lora, + kind: self.kind.clone(), + is_xlora, }, }))) } diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 72fb055d4..c07f63e3f 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod model_config; pub(crate) mod tokenizer; pub(crate) mod tokens; pub(crate) mod varbuilder_utils; diff --git a/mistralrs-core/src/utils/model_config.rs b/mistralrs-core/src/utils/model_config.rs new file mode 100644 index 000000000..7315f27b5 --- /dev/null +++ b/mistralrs-core/src/utils/model_config.rs @@ -0,0 +1,323 @@ +use super::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; +use anyhow::Result; +use candle_core::quantized::{ggml_file, gguf_file}; +use candle_nn::VarBuilder; +use std::{collections::HashMap, path::PathBuf}; + +use crate::{ + lora::{LoraConfig, Ordering}, + pipeline::ModelPaths, + xlora_models::XLoraConfig, + DeviceMapMetadata, +}; + +#[derive(derive_more::From)] +pub struct FileGGML { + pub ct: ggml_file::Content, + pub gqa: usize, +} + +#[derive(derive_more::From)] +pub struct FileGGUF<'a> { + pub ct: gguf_file::Content, + pub reader: &'a mut std::fs::File, +} + +#[derive(derive_more::From)] +pub struct Device<'a> { + pub device: &'a candle_core::Device, + pub mapper: DeviceMapMetadata, +} + +pub struct Adapter<'a> { + pub xlora_config: Option, + pub lora_config: &'a [((String, String), LoraConfig)], + pub vb: VarBuilder<'a>, + pub ordering: &'a Ordering, + pub preload_adapters: Option, LoraConfig)>>, +} + +impl<'a> Adapter<'a> { + // NOTE: It is not possible to store references for values returned by: load_preload_adapters() + from_mmaped_safetensors(), + // As referenced value would drop after this method, Adapter takes ownership of vb + preload_adapters + // and then passes by reference to the `from_gguf()` / `from_ggml()` methods when proxying to params. + // NOTE: Due to reference usage persisting in returned struct, additional lifetime annotations were required. 
+ #[allow(clippy::borrowed_box)] + pub fn try_new<'b: 'a>( + paths: &'b Box, + device: &'b candle_core::Device, + silent: bool, + is_xlora: bool, + ) -> Result { + let lora_config = paths.get_adapter_configs().as_ref().unwrap(); + let ordering = paths.get_ordering().as_ref().unwrap(); + let preload_adapters = load_preload_adapters( + paths.get_lora_preload_adapter_info(), + candle_core::DType::F32, + device, + silent, + )?; + + // X-LoRA support: + let mut xlora_paths: Vec = vec![]; + let mut xlora_config: Option = None; + if is_xlora { + xlora_paths = vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()]; + xlora_config = Some(paths.get_classifier_config().as_ref().unwrap().clone()); + } + + // Create VarBuilder: + // TODO: `from_mmaped_safetensors` has `xlora_paths` as the 2nd param (_valid but params need to be named better_) + let vb = from_mmaped_safetensors( + xlora_paths, + paths + .get_adapter_filenames() + .as_ref() + .unwrap() + .iter() + .map(|(_, x)| (*x).to_owned()) + .collect::>(), + candle_core::DType::F32, + device, + silent, + )?; + + Ok(Self { + lora_config, + xlora_config, + vb, + ordering, + preload_adapters, + }) + } +} + +// New type wrappers that segment the distinct parameter sets used by `from_ggml()` + `from_gguf()` methods: +pub struct ParamsGGML(pub FileGGML); +pub struct ParamsGGUF<'a>(pub FileGGUF<'a>, pub Device<'a>); + +// A `None` type vs the `Some` type (`Adapter<'a>`) +pub struct NoAdapter {} + +// Marker traits to restrict type input: +// (required workaround to support impl on subtypes, otherwise would use an enum) +pub trait QuantParams {} +impl QuantParams for ParamsGGML {} +impl QuantParams for ParamsGGUF<'_> {} + +// Emulates `Option` but is compatible as a type bound in `impl` for Some vs None +pub trait MaybeAdapter {} +impl MaybeAdapter for Adapter<'_> {} +impl MaybeAdapter for NoAdapter {} + +// `derive_more::From` provides a terser construction for enum variants of `ModelParams`. +#[derive(derive_more::From)] +pub struct Config { + pub quant: Q, + pub adapter: A, +} + +// NOTE: Variantly used for `.expect_quantized()` / `.expect_adapted()` methods +// `where` clause required due to bug with inline bounds: +// https://github.com/luker-os/variantly/pull/16 +#[allow(clippy::large_enum_variant)] +#[derive(variantly::Variantly)] +pub enum ModelParams<'a, Q> +where + Q: QuantParams, +{ + Quantized(Config), + Adapted(Config>), +} + +// A `builder()` method is derived from the `new()` method and it's params (derived builder struct fields). +// NOTE: Intended to be built via fluent API in a single line, cannot conditionally append params. +// `.adapter(Adapter<' >)` or for conditional usage `.and_adapter(Option)` can be used. +// Otherwise omitting an `.adapter()` call prior to calling `build()` is ok, defaults to `None`. 
+#[buildstructor::buildstructor] +impl<'a, Q: QuantParams> ModelParams<'a, Q> { + #[builder] + pub fn new<'b: 'a>(quant: Q, adapter: Option>) -> Self { + match adapter { + None => Self::Quantized((quant, NoAdapter {}).into()), + Some(a) => Self::Adapted((quant, a).into()), + } + } +} + +// Traits for the existing methods used across various model types to impl `from_ggml()` / `from_gguf()` +// Basic: +pub trait FromGGML { + fn from_ggml(ct: ggml_file::Content, gqa: usize) -> Result + where + Self: Sized; +} + +pub trait FromGGUF { + fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &candle_core::Device, + mapper: DeviceMapMetadata, + ) -> Result + where + Self: Sized; +} + +// Extended variants: +pub trait FromAdapterGGML { + fn from_ggml( + ct: ggml_file::Content, + gqa: usize, + lora_config: &[((String, String), LoraConfig)], + vb: &VarBuilder, + ordering: &Ordering, + xlora_config: Option, + preload_adapters: &Option>, + ) -> Result + where + Self: Sized; +} +pub trait FromAdapterGGUF { + #[allow(clippy::too_many_arguments)] + fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &candle_core::Device, + lora_config: &[((String, String), LoraConfig)], + vb: &VarBuilder, + ordering: &Ordering, + xlora_config: Option, + mapper: DeviceMapMetadata, + preload_adapters: &Option>, + ) -> Result + where + Self: Sized; +} + +// NOTE: Below is a workaround to proxy params to the existing API methods `get_gguf()` / `get_gmml()` traits covered above. +impl Config { + pub fn try_into_model(self) -> Result { + // Destructure props: + let ParamsGGML(FileGGML { ct, gqa }) = self.quant; + + // Forwards all structured fields above into the required flattened param sequence: + T::from_ggml(ct, gqa) + } +} + +impl Config> { + pub fn try_into_model(self) -> Result { + // Destructure props: + let ParamsGGML(FileGGML { ct, gqa }) = self.quant; + + let Adapter { + xlora_config, + lora_config, + vb, + ordering, + preload_adapters, + } = self.adapter; + + // Forwards all structured fields above into the required flattened param sequence: + T::from_ggml( + ct, + gqa, + lora_config, + &vb, + ordering, + xlora_config, + &preload_adapters, + ) + } +} + +impl Config, NoAdapter> { + pub fn try_into_model(self) -> Result { + // Destructure props: + let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant; + + // Forwards all structured fields above into the required flattened param sequence: + T::from_gguf(ct, reader, device, mapper) + } +} + +impl Config, Adapter<'_>> { + pub fn try_into_model(self) -> Result { + // Destructure props: + let ParamsGGUF(FileGGUF { ct, reader }, Device { device, mapper }) = self.quant; + + let Adapter { + xlora_config, + lora_config, + vb, + ordering, + preload_adapters, + } = self.adapter; + + // Forwards all structured fields above into the required flattened param sequence: + T::from_gguf( + ct, + reader, + device, + lora_config, + &vb, + ordering, + xlora_config, + mapper, + &preload_adapters, + ) + } +} + +use crate::{ + models::quantized_llama::ModelWeights as QLlama, + models::quantized_phi2::ModelWeights as QPhi, + models::quantized_phi3::ModelWeights as QPhi3, + xlora_models::{XLoraQLlama, XLoraQPhi3}, +}; +use akin::akin; + +impl TryFrom> for QLlama { + type Error = candle_core::Error; + + fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result { + let config = params.expect_quantized("`Config` should be GGML Quantized"); + config.try_into_model() + } +} + +impl TryFrom> for XLoraQLlama { + type Error = 
candle_core::Error; + + fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result { + let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter"); + config.try_into_model() + } +} + +akin! { + let &models_gguf = [QLlama, QPhi, QPhi3]; + + impl TryFrom>> for *models_gguf { + type Error = candle_core::Error; + + fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result { + let config = params.expect_quantized("`Config` should be GGUF Quantized"); + config.try_into_model() + } + } +} + +akin! { + let &models_gguf_a = [XLoraQLlama, XLoraQPhi3]; + + impl TryFrom>> for *models_gguf_a { + type Error = candle_core::Error; + + fn try_from(params: ModelParams<'_, ParamsGGUF<'_>>) -> Result { + let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter"); + config.try_into_model() + } + } +} diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index 48281c9dd..9dd26c0e9 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -737,7 +737,7 @@ impl NormalModel for XLoraModel { &self.device } fn is_xlora(&self) -> bool { - false + true } fn max_seq_len(&self) -> usize { self.max_seq_len diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs index 2a0b83581..3b68cbcd8 100644 --- a/mistralrs-core/src/xlora_models/quantized_llama.rs +++ b/mistralrs-core/src/xlora_models/quantized_llama.rs @@ -21,6 +21,7 @@ use crate::DeviceMapMetadata; use super::classifier::XLoraClassifier; use super::{verify_sanity_adapters, NonGranularState, ScalingsMaker, XLoraConfig}; +use crate::utils::model_config as ModelConfig; const MAX_SEQ_LEN: u32 = 4096; const SUPPORTED_LAYERS: [&str; 7] = [ @@ -269,8 +270,8 @@ pub struct ModelWeights { mapper: Option>, } -impl ModelWeights { - pub fn from_ggml( +impl ModelConfig::FromAdapterGGML for ModelWeights { + fn from_ggml( mut ct: ggml_file::Content, gqa: usize, lora_config: &[((String, String), LoraConfig)], @@ -440,9 +441,11 @@ impl ModelWeights { mapper: None, }) } +} +impl ModelConfig::FromAdapterGGUF for ModelWeights { #[allow(clippy::too_many_arguments)] - pub fn from_gguf( + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -710,7 +713,9 @@ impl ModelWeights { mapper: Some(mapper), }) } +} +impl ModelWeights { pub fn activate_adapters(&mut self, adapter_names: Vec) -> Result { let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs index 210f2275b..248bc4175 100644 --- a/mistralrs-core/src/xlora_models/quantized_phi3.rs +++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs @@ -33,6 +33,7 @@ use super::Cache; use super::NonGranularState; use super::ScalingsMaker; use super::XLoraConfig; +use crate::utils::model_config as ModelConfig; const SUPPORTED_LAYERS: [&str; 4] = [ "self_attn.qkv_proj", @@ -212,9 +213,9 @@ fn precomput_freqs_cis( Ok((cos, sin)) } -impl ModelWeights { +impl ModelConfig::FromAdapterGGUF for ModelWeights { #[allow(clippy::too_many_arguments)] - pub fn from_gguf( + fn from_gguf( ct: gguf_file::Content, reader: &mut R, device: &Device, @@ -349,7 +350,9 @@ impl ModelWeights { }), }) } +} +impl ModelWeights { pub fn activate_adapters(&mut self, adapter_names: Vec) -> Result { let mut sum = 0; for layer in self.layers.iter_mut() {