From cfe2fd3674c9853e8f3738d7705ed727586efa9d Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Tue, 28 May 2024 20:10:38 -0400 Subject: [PATCH 1/2] Add examples readme --- examples/README.md | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 examples/README.md diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 000000000..043a2211d --- /dev/null +++ b/examples/README.md @@ -0,0 +1,4 @@ +# Examples +- Python: [examples here](python) +- HTTP Server: [examples here](server) +- Rust: [examples here](../mistralrs/examples/) \ No newline at end of file From ddba24b2813cb2614f9f40adacd59e96caabf917 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Tue, 28 May 2024 21:12:23 -0400 Subject: [PATCH 2/2] Add an example and fixes --- chat_templates/llama2.json | 3 + chat_templates/llama3.json | 3 + chat_templates/mistral.json | 3 + chat_templates/phi3.json | 3 + mistralrs-core/src/pipeline/chat_template.rs | 29 ++++---- mistralrs-core/src/pipeline/gguf.rs | 72 +++++++++++++------ mistralrs-core/src/pipeline/gguf_tokenizer.rs | 21 +++++- mistralrs-core/src/pipeline/macros.rs | 20 ++++-- mistralrs-core/src/pipeline/mod.rs | 12 ++-- mistralrs/Cargo.toml | 4 ++ mistralrs/examples/gguf_locally/main.rs | 64 +++++++++++++++++ mistralrs/examples/quantized/main.rs | 1 + 12 files changed, 188 insertions(+), 47 deletions(-) create mode 100644 chat_templates/llama2.json create mode 100644 chat_templates/llama3.json create mode 100644 chat_templates/mistral.json create mode 100644 chat_templates/phi3.json create mode 100644 mistralrs/examples/gguf_locally/main.rs diff --git a/chat_templates/llama2.json b/chat_templates/llama2.json new file mode 100644 index 000000000..800a077f2 --- /dev/null +++ b/chat_templates/llama2.json @@ -0,0 +1,3 @@ +{ + "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" +} \ No newline at end of file diff --git a/chat_templates/llama3.json b/chat_templates/llama3.json new file mode 100644 index 000000000..61bafeb2e --- /dev/null +++ b/chat_templates/llama3.json @@ -0,0 +1,3 @@ +{ + "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" +} \ No newline at end of file diff --git a/chat_templates/mistral.json b/chat_templates/mistral.json new file mode 100644 index 000000000..15544fda6 --- /dev/null +++ b/chat_templates/mistral.json @@ -0,0 +1,3 @@ +{ + "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" +} \ No newline at end of file diff --git a/chat_templates/phi3.json b/chat_templates/phi3.json new file mode 100644 index 000000000..6d92f29e6 --- /dev/null +++ b/chat_templates/phi3.json @@ -0,0 +1,3 @@ +{ + "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}" +} \ No newline at end of file diff --git a/mistralrs-core/src/pipeline/chat_template.rs b/mistralrs-core/src/pipeline/chat_template.rs index ee7dfa115..e419b8901 100644 --- a/mistralrs-core/src/pipeline/chat_template.rs +++ b/mistralrs-core/src/pipeline/chat_template.rs @@ -30,9 +30,9 @@ fn raise_exception(msg: String) -> Result { } #[derive(Debug, Deserialize)] -pub struct Unk(#[serde(with = "either::serde_untagged")] pub Either); -#[derive(Debug, Deserialize)] -pub struct Bos(#[serde(with = "either::serde_untagged")] pub Either); +pub struct BeginEndUnkTok( + #[serde(with = "either::serde_untagged")] pub Either, +); #[allow(dead_code)] #[derive(Debug, Deserialize)] @@ -41,23 +41,22 @@ pub struct ChatTemplate { add_eos_token: Option, added_tokens_decoder: Option>, additional_special_tokens: Option>, - pub bos_token: Option, + pub bos_token: Option, /// Jinja format chat templating for chat completion. /// See: https://huggingface.co/docs/transformers/chat_templating pub chat_template: Option, clean_up_tokenization_spaces: Option, device_map: Option, - #[serde(with = "either::serde_untagged")] - pub eos_token: Either, + pub eos_token: Option, legacy: Option, - model_max_length: f64, + model_max_length: Option, pad_token: Option, sp_model_kwargs: Option>, spaces_between_special_tokens: Option, - tokenizer_class: String, + tokenizer_class: Option, truncation_size: Option, - pub unk_token: Option, + pub unk_token: Option, use_default_system_prompt: Option, } @@ -66,10 +65,10 @@ impl ChatTemplate { self.chat_template.is_some() } - pub fn eos_tok(&self) -> String { - match self.eos_token { - Either::Left(ref lit) => lit.clone(), - Either::Right(ref added) => added.content.clone(), + pub fn eos_tok(&self) -> Option { + match self.eos_token.as_ref()?.0 { + Either::Left(ref lit) => Some(lit.clone()), + Either::Right(ref added) => Some(added.content.clone()), } } @@ -93,7 +92,7 @@ pub fn calculate_eos_tokens( gen_conf: Option, tokenizer: &Tokenizer, ) -> Vec { - let mut eos_tok_ids = vec![chat_template.eos_tok()]; + let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default(); let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default(); for alternate in SUPPORTED_ALTERNATE_EOS { @@ -173,7 +172,7 @@ pub fn apply_chat_template_to( add_generation_prompt: bool, template: &str, bos_tok: Option, - eos_tok: &str, + eos_tok: Option, unk_tok: Option, ) -> Result { let mut env = Environment::new(); diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index ae3bb9dca..71520b1d6 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -6,12 +6,13 @@ use super::{ use crate::aici::bintokens::build_tok_trie; use crate::aici::toktree::TokTrie; use crate::lora::Ordering; -use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig}; -use crate::pipeline::gguf_tokenizer::convert_ggml_to_hf_tokenizer; +use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig}; +use crate::pipeline::gguf_tokenizer::{convert_ggml_to_hf_tokenizer, ConversionResult}; use crate::pipeline::{get_chat_template, Cache}; use crate::pipeline::{ChatTemplate, LocalModelPaths}; use crate::prefix_cacher::PrefixCacheManager; use crate::sequence::Sequence; +use crate::utils::tokenizer::get_tokenizer; use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use crate::xlora_models::NonGranularState; use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG}; @@ -28,6 +29,7 @@ use candle_core::quantized::{ GgmlDType, }; use candle_core::{DType, Device, Tensor}; +use either::Either; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use rand_isaac::Isaac64Rng; use std::fs; @@ -61,10 +63,10 @@ pub struct GGUFPipeline { } pub struct GGUFLoader { - model_id: String, + model_id: Option, config: GGUFSpecificConfig, - quantized_model_id: Option, - quantized_filename: Option, + quantized_model_id: String, + quantized_filename: String, xlora_model_id: Option, xlora_order: Option, no_kv_cache: bool, @@ -189,7 +191,7 @@ impl GGUFLoaderBuilder { pub fn build(self) -> Box { Box::new(GGUFLoader { - model_id: self.model_id.unwrap(), + model_id: self.model_id, config: self.config, xlora_model_id: self.xlora_model_id, kind: self.kind, @@ -197,8 +199,8 @@ impl GGUFLoaderBuilder { no_kv_cache: self.no_kv_cache, chat_template: self.chat_template, tgt_non_granular_index: self.tgt_non_granular_index, - quantized_filename: Some(self.quantized_filename), - quantized_model_id: Some(self.quantized_model_id), + quantized_filename: self.quantized_filename, + quantized_model_id: self.quantized_model_id, }) } } @@ -208,8 +210,8 @@ impl GGUFLoader { pub fn new( model_id: Option, config: GGUFSpecificConfig, - quantized_model_id: Option, - quantized_filename: Option, + quantized_model_id: String, + quantized_filename: String, xlora_model_id: Option, kind: ModelKind, xlora_order: Option, @@ -218,13 +220,15 @@ impl GGUFLoader { tgt_non_granular_index: Option, ) -> Self { let model_id = if let Some(id) = model_id { - id - } else { + Some(id) + } else if let Some(xlora_order) = xlora_order.clone() { info!( "Using adapter base model ID: `{}`", - xlora_order.as_ref().unwrap().base_model_id + xlora_order.base_model_id ); - xlora_order.as_ref().unwrap().base_model_id.clone() + Some(xlora_order.base_model_id.clone()) + } else { + None }; Self { model_id, @@ -280,8 +284,8 @@ impl Loader for GGUFLoader { &token_source, revision, self, - self.quantized_model_id, - self.quantized_filename, + self.quantized_model_id.clone(), + self.quantized_filename.clone(), silent ); self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) @@ -356,7 +360,21 @@ impl Loader for GGUFLoader { info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); } - let tokenizer = convert_ggml_to_hf_tokenizer(&model)?; + let ConversionResult { + tokenizer, + bos, + eos, + unk, + } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() { + convert_ggml_to_hf_tokenizer(&model)? + } else { + ConversionResult { + tokenizer: get_tokenizer(paths.get_tokenizer_filename())?, + bos: None, + eos: None, + unk: None, + } + }; let mut is_lora = false; let model = match self.kind { @@ -481,7 +499,7 @@ impl Loader for GGUFLoader { let gen_conf: Option = paths .get_gen_conf_filename() .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap()); - let chat_template = get_chat_template(paths, &self.chat_template); + let mut chat_template = get_chat_template(paths, &self.chat_template); let max_seq_len = match model { Model::Llama(ref l) => l.max_seq_len, @@ -502,6 +520,17 @@ impl Loader for GGUFLoader { Model::Phi3(ref model) => model.cache.lock().len(), Model::XLoraPhi3(ref model) => model.cache.lock().len(), }; + + if chat_template.bos_token.is_none() && bos.is_some() { + chat_template.bos_token = Some(BeginEndUnkTok(Either::Left(bos.unwrap()))); + } + if chat_template.eos_token.is_none() && eos.is_some() { + chat_template.eos_token = Some(BeginEndUnkTok(Either::Left(eos.unwrap()))); + } + if chat_template.unk_token.is_none() && unk.is_some() { + chat_template.unk_token = Some(BeginEndUnkTok(Either::Left(unk.unwrap()))); + } + let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer); Ok(Arc::new(Mutex::new(GGUFPipeline { model, @@ -509,7 +538,10 @@ impl Loader for GGUFLoader { tokenizer: tokenizer.into(), no_kv_cache: self.no_kv_cache, chat_template: Arc::new(chat_template), - model_id: self.model_id.clone(), + model_id: self + .model_id + .clone() + .unwrap_or(self.quantized_model_id.clone()), non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| { NonGranularState { non_granular_index: Arc::new(Mutex::new(0)), @@ -532,7 +564,7 @@ impl Loader for GGUFLoader { fn get_id(&self) -> String { self.xlora_model_id .as_deref() - .unwrap_or(&self.model_id) + .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id)) .to_string() } diff --git a/mistralrs-core/src/pipeline/gguf_tokenizer.rs b/mistralrs-core/src/pipeline/gguf_tokenizer.rs index 1a8333616..1d6985c1f 100644 --- a/mistralrs-core/src/pipeline/gguf_tokenizer.rs +++ b/mistralrs-core/src/pipeline/gguf_tokenizer.rs @@ -12,7 +12,14 @@ use tracing::info; use crate::DEBUG; -pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { +pub struct ConversionResult { + pub tokenizer: Tokenizer, + pub bos: Option, + pub eos: Option, + pub unk: Option, +} + +pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { let model = content.metadata["tokenizer.ggml.model"] .to_string() .expect("GGUF tokenizer model is not a string.") @@ -67,6 +74,10 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { .to_u32() .expect("GGUF unk token is not u32"); + let bos_str = tokens[bos as usize].clone(); + let eos_str = tokens[eos as usize].clone(); + let unk_str = tokens[unk as usize].clone(); + let (tokenizer, ty) = match model.as_str() { "llama" | "replit" => { // unigram @@ -112,7 +123,12 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result { if DEBUG.load(Ordering::Relaxed) { info!("Tokenizer: {tokenizer:?}"); } - Ok(tokenizer) + Ok(ConversionResult { + tokenizer, + bos: Some(bos_str), + eos: Some(eos_str), + unk: Some(unk_str), + }) } mod tests { @@ -152,6 +168,7 @@ mod tests { .map_err(anyhow::Error::msg)?, ) .map_err(anyhow::Error::msg) + .map(|res| res.tokenizer) } other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"), } diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 7f8f663d5..6e29c940c 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -146,12 +146,14 @@ macro_rules! get_paths_gguf { .with_token(get_token($token_source)?) .build()?; let revision = $revision.unwrap_or("main".to_string()); + let model_id_this = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone()); + let model_id_copy = model_id_this.clone(); let api = api.repo(Repo::with_revision( - $this.model_id.clone(), + model_id_this.clone(), RepoType::Model, revision.clone(), )); - let model_id = std::path::Path::new(&$this.model_id); + let model_id = std::path::Path::new(&model_id_copy); let chat_template = if let Some(ref p) = $this.chat_template { if p.ends_with(".json") { @@ -171,8 +173,8 @@ macro_rules! get_paths_gguf { let filenames = get_model_paths( revision.clone(), &$token_source, - &$quantized_model_id, - &$quantized_filename, + &Some($quantized_model_id), + &Some($quantized_filename), &api, &model_id, )?; @@ -185,7 +187,7 @@ macro_rules! get_paths_gguf { xlora_config, lora_preload_adapter_info, } = get_xlora_paths( - $this.model_id.clone(), + model_id_this, &$this.xlora_model_id, &$token_source, revision.clone(), @@ -205,8 +207,14 @@ macro_rules! get_paths_gguf { None }; + let tokenizer_filename = if $this.model_id.is_some() { + $crate::api_get_file!(api, "tokenizer.json", model_id) + } else { + PathBuf::from_str("")? + }; + Ok(Box::new($path_name { - tokenizer_filename: PathBuf::from_str("")?, + tokenizer_filename, config_filename: PathBuf::from_str("")?, filenames, xlora_adapter_configs: adapter_configs, diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 9d7dbee83..5dae166a3 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -598,9 +598,13 @@ pub trait Pipeline: Send + Sync { } else { None }; - let eos_tok = match chat_template.eos_token { - Either::Left(ref lit) => lit, - Either::Right(ref added) => &added.content, + let eos_tok = if let Some(ref unk) = self.get_chat_template().eos_token { + match unk.0 { + Either::Left(ref lit) => Some(lit.to_string()), + Either::Right(ref added) => Some(added.content.to_string()), + } + } else { + None }; let unk_tok = if let Some(ref unk) = self.get_chat_template().unk_token { match unk.0 { @@ -1436,7 +1440,7 @@ mod tests { true, template, Some(bos.to_string()), - eos, + Some(eos.to_string()), Some(unk.to_string()), ) .unwrap_or_else(|_| panic!("Template number {i}")); diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml index d105982bd..49027c7e0 100644 --- a/mistralrs/Cargo.toml +++ b/mistralrs/Cargo.toml @@ -52,4 +52,8 @@ required-features = [] [[example]] name = "lora_activation" +required-features = [] + +[[example]] +name = "gguf_locally" required-features = [] \ No newline at end of file diff --git a/mistralrs/examples/gguf_locally/main.rs b/mistralrs/examples/gguf_locally/main.rs new file mode 100644 index 000000000..b04fc9fa5 --- /dev/null +++ b/mistralrs/examples/gguf_locally/main.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; +use tokio::sync::mpsc::channel; + +use mistralrs::{ + Constraint, Device, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs, + MistralRsBuilder, NormalRequest, Request, RequestMessage, Response, SamplingParams, + SchedulerMethod, TokenSource, +}; + +fn setup() -> anyhow::Result> { + // Select a Mistral model + // We do not use any files from HF servers here, and instead load the + // chat template from the specified file, and the tokenizer and model from a + // local GGUF file at the path `.` + let loader = GGUFLoaderBuilder::new( + GGUFSpecificConfig { repeat_last_n: 64 }, + Some("chat_templates/mistral.json".to_string()), + None, + ".".to_string(), + "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(), + ) + .build(); + // Load, into a Pipeline + let pipeline = loader.load_model_from_hf( + None, + TokenSource::CacheToken, + None, + &Device::cuda_if_available(0)?, + false, + DeviceMapMetadata::dummy(), + None, + )?; + // Create the MistralRs, which is a runner + Ok(MistralRsBuilder::new(pipeline, SchedulerMethod::Fixed(5.try_into().unwrap())).build()) +} + +fn main() -> anyhow::Result<()> { + let mistralrs = setup()?; + + let (tx, mut rx) = channel(10_000); + let request = Request::Normal(NormalRequest { + messages: RequestMessage::Completion { + text: "Hello! My name is ".to_string(), + echo_prompt: false, + best_of: 1, + }, + sampling_params: SamplingParams::default(), + response: tx, + return_logprobs: false, + is_streaming: false, + id: 0, + constraint: Constraint::None, + suffix: None, + adapters: None, + }); + mistralrs.get_sender().blocking_send(request)?; + + let response = rx.blocking_recv().unwrap(); + match response { + Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text), + _ => unreachable!(), + } + Ok(()) +} diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs index 58f1ac92b..b6539edaf 100644 --- a/mistralrs/examples/quantized/main.rs +++ b/mistralrs/examples/quantized/main.rs @@ -9,6 +9,7 @@ use mistralrs::{ fn setup() -> anyhow::Result> { // Select a Mistral model + // This uses a model, tokenizer, and chat template, from HF hub. let loader = GGUFLoaderBuilder::new( GGUFSpecificConfig { repeat_last_n: 64 }, None,