Merge pull request #357 from EricLBuehler/examples
Add an example
EricLBuehler authored May 29, 2024
2 parents 71bdd2f + ddba24b commit 9273f2a
Showing 13 changed files with 192 additions and 47 deletions.
3 changes: 3 additions & 0 deletions chat_templates/llama2.json
@@ -0,0 +1,3 @@
+{
+  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
+}
3 changes: 3 additions & 0 deletions chat_templates/llama3.json
@@ -0,0 +1,3 @@
+{
+  "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+}
3 changes: 3 additions & 0 deletions chat_templates/mistral.json
@@ -0,0 +1,3 @@
+{
+  "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+}
3 changes: 3 additions & 0 deletions chat_templates/phi3.json
@@ -0,0 +1,3 @@
+{
+  "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
+}
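The four templates above are plain Jinja strings that get rendered at request time by the `minijinja` engine used in `chat_template.rs` below. As a quick illustration of how such a template turns a message list into a prompt, here is a standalone sketch (not part of this commit): it renders the Llama 3 template above, and the `Msg` type and the `<|begin_of_text|>` BOS value are illustrative assumptions.

```rust
use minijinja::{context, Environment};
use serde::Serialize;

// Hypothetical message type for the sketch; any Serialize type with
// `role`/`content` fields works, since the template indexes those keys.
#[derive(Serialize)]
struct Msg {
    role: &'static str,
    content: &'static str,
}

fn main() -> Result<(), minijinja::Error> {
    // The llama3.json template, inlined verbatim as a raw string so the
    // `\n` escapes are interpreted by minijinja, not by Rust.
    let template = r"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}";

    let mut env = Environment::new();
    env.add_template("llama3", template)?;

    let rendered = env.get_template("llama3")?.render(context! {
        messages => vec![Msg { role: "user", content: "Hello!" }],
        bos_token => "<|begin_of_text|>", // illustrative; normally read from the tokenizer config or GGUF metadata
        add_generation_prompt => true,
    })?;

    // Produces (with \n shown escaped):
    // <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
    println!("{rendered}");
    Ok(())
}
```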
4 changes: 4 additions & 0 deletions examples/README.md
@@ -0,0 +1,4 @@
+# Examples
+- Python: [examples here](python)
+- HTTP Server: [examples here](server)
+- Rust: [examples here](../mistralrs/examples/)
29 changes: 14 additions & 15 deletions mistralrs-core/src/pipeline/chat_template.rs
@@ -30,9 +30,9 @@ fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
 }
 
 #[derive(Debug, Deserialize)]
-pub struct Unk(#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>);
-#[derive(Debug, Deserialize)]
-pub struct Bos(#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>);
+pub struct BeginEndUnkTok(
+    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
+);
 
 #[allow(dead_code)]
 #[derive(Debug, Deserialize)]
@@ -41,23 +41,22 @@ pub struct ChatTemplate {
     add_eos_token: Option<bool>,
     added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
     additional_special_tokens: Option<Vec<String>>,
-    pub bos_token: Option<Bos>,
+    pub bos_token: Option<BeginEndUnkTok>,
 
     /// Jinja format chat templating for chat completion.
     /// See: https://huggingface.co/docs/transformers/chat_templating
     pub chat_template: Option<String>,
     clean_up_tokenization_spaces: Option<bool>,
     device_map: Option<String>,
-    #[serde(with = "either::serde_untagged")]
-    pub eos_token: Either<String, AddedTokensDecoder>,
+    pub eos_token: Option<BeginEndUnkTok>,
     legacy: Option<bool>,
-    model_max_length: f64,
+    model_max_length: Option<f64>,
     pad_token: Option<String>,
     sp_model_kwargs: Option<HashMap<String, String>>,
     spaces_between_special_tokens: Option<bool>,
-    tokenizer_class: String,
+    tokenizer_class: Option<String>,
     truncation_size: Option<String>,
-    pub unk_token: Option<Unk>,
+    pub unk_token: Option<BeginEndUnkTok>,
     use_default_system_prompt: Option<bool>,
 }

@@ -66,10 +65,10 @@ impl ChatTemplate {
         self.chat_template.is_some()
     }
 
-    pub fn eos_tok(&self) -> String {
-        match self.eos_token {
-            Either::Left(ref lit) => lit.clone(),
-            Either::Right(ref added) => added.content.clone(),
+    pub fn eos_tok(&self) -> Option<String> {
+        match self.eos_token.as_ref()?.0 {
+            Either::Left(ref lit) => Some(lit.clone()),
+            Either::Right(ref added) => Some(added.content.clone()),
         }
     }

@@ -93,7 +92,7 @@ pub fn calculate_eos_tokens(
    gen_conf: Option<GenerationConfig>,
    tokenizer: &Tokenizer,
 ) -> Vec<u32> {
-    let mut eos_tok_ids = vec![chat_template.eos_tok()];
+    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
 
    for alternate in SUPPORTED_ALTERNATE_EOS {
@@ -173,7 +172,7 @@ pub fn apply_chat_template_to(
    add_generation_prompt: bool,
    template: &str,
    bos_tok: Option<String>,
-    eos_tok: &str,
+    eos_tok: Option<String>,
    unk_tok: Option<String>,
 ) -> Result<String> {
    let mut env = Environment::new();
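`BeginEndUnkTok` collapses the former `Unk`/`Bos` wrappers into one type because `tokenizer_config.json` files store each special token in one of two shapes: a bare string or an added-token object. A standalone sketch of the untagged deserialization this relies on (simplified `AddedTokensDecoder` and hypothetical token values; assumes the `either` crate with its `serde` feature, plus `serde` and `serde_json`):

```rust
use either::Either;
use serde::Deserialize;

// Simplified stand-in for the real AddedTokensDecoder, which carries more
// fields (lstrip, rstrip, normalized, ...); `content` is enough here.
#[derive(Debug, Deserialize)]
struct AddedTokensDecoder {
    content: String,
}

#[derive(Debug, Deserialize)]
struct BeginEndUnkTok(#[serde(with = "either::serde_untagged")] Either<String, AddedTokensDecoder>);

fn main() -> serde_json::Result<()> {
    // Bare-string form, common in older tokenizer configs:
    let plain: BeginEndUnkTok = serde_json::from_str(r#""</s>""#)?;
    // Object form, written by newer tokenizers:
    let object: BeginEndUnkTok = serde_json::from_str(r#"{ "content": "</s>" }"#)?;

    // Mirrors how eos_tok() extracts the literal from either shape.
    let as_str = |t: &BeginEndUnkTok| match &t.0 {
        Either::Left(lit) => lit.clone(),
        Either::Right(added) => added.content.clone(),
    };
    assert_eq!(as_str(&plain), as_str(&object)); // both yield "</s>"
    Ok(())
}
```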
72 changes: 52 additions & 20 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -6,12 +6,13 @@ use super::{
 use crate::aici::bintokens::build_tok_trie;
 use crate::aici::toktree::TokTrie;
 use crate::lora::Ordering;
-use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
-use crate::pipeline::gguf_tokenizer::convert_ggml_to_hf_tokenizer;
+use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig};
+use crate::pipeline::gguf_tokenizer::{convert_ggml_to_hf_tokenizer, ConversionResult};
 use crate::pipeline::{get_chat_template, Cache};
 use crate::pipeline::{ChatTemplate, LocalModelPaths};
 use crate::prefix_cacher::PrefixCacheManager;
 use crate::sequence::Sequence;
+use crate::utils::tokenizer::get_tokenizer;
 use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
 use crate::xlora_models::NonGranularState;
 use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG};
@@ -28,6 +29,7 @@ use candle_core::quantized::{
     GgmlDType,
 };
 use candle_core::{DType, Device, Tensor};
+use either::Either;
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use rand_isaac::Isaac64Rng;
 use std::fs;
@@ -61,10 +63,10 @@ pub struct GGUFPipeline {
 }
 
 pub struct GGUFLoader {
-    model_id: String,
+    model_id: Option<String>,
     config: GGUFSpecificConfig,
-    quantized_model_id: Option<String>,
-    quantized_filename: Option<String>,
+    quantized_model_id: String,
+    quantized_filename: String,
     xlora_model_id: Option<String>,
     xlora_order: Option<Ordering>,
     no_kv_cache: bool,
@@ -189,16 +191,16 @@ impl GGUFLoaderBuilder {
 
     pub fn build(self) -> Box<dyn Loader> {
         Box::new(GGUFLoader {
-            model_id: self.model_id.unwrap(),
+            model_id: self.model_id,
             config: self.config,
             xlora_model_id: self.xlora_model_id,
             kind: self.kind,
             xlora_order: self.xlora_order,
             no_kv_cache: self.no_kv_cache,
             chat_template: self.chat_template,
             tgt_non_granular_index: self.tgt_non_granular_index,
-            quantized_filename: Some(self.quantized_filename),
-            quantized_model_id: Some(self.quantized_model_id),
+            quantized_filename: self.quantized_filename,
+            quantized_model_id: self.quantized_model_id,
         })
     }
 }
@@ -208,8 +210,8 @@ impl GGUFLoader {
     pub fn new(
         model_id: Option<String>,
         config: GGUFSpecificConfig,
-        quantized_model_id: Option<String>,
-        quantized_filename: Option<String>,
+        quantized_model_id: String,
+        quantized_filename: String,
         xlora_model_id: Option<String>,
         kind: ModelKind,
         xlora_order: Option<Ordering>,
@@ -218,13 +220,15 @@ impl GGUFLoader {
         tgt_non_granular_index: Option<usize>,
     ) -> Self {
         let model_id = if let Some(id) = model_id {
-            id
-        } else {
+            Some(id)
+        } else if let Some(xlora_order) = xlora_order.clone() {
             info!(
                 "Using adapter base model ID: `{}`",
-                xlora_order.as_ref().unwrap().base_model_id
+                xlora_order.base_model_id
             );
-            xlora_order.as_ref().unwrap().base_model_id.clone()
+            Some(xlora_order.base_model_id.clone())
+        } else {
+            None
         };
         Self {
             model_id,
@@ -280,8 +284,8 @@ impl Loader for GGUFLoader {
             &token_source,
             revision,
             self,
-            self.quantized_model_id,
-            self.quantized_filename,
+            self.quantized_model_id.clone(),
+            self.quantized_filename.clone(),
             silent
         );
         self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant)
@@ -356,7 +360,21 @@ impl Loader for GGUFLoader {
             info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
         }
 
-        let tokenizer = convert_ggml_to_hf_tokenizer(&model)?;
+        let ConversionResult {
+            tokenizer,
+            bos,
+            eos,
+            unk,
+        } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
+            convert_ggml_to_hf_tokenizer(&model)?
+        } else {
+            ConversionResult {
+                tokenizer: get_tokenizer(paths.get_tokenizer_filename())?,
+                bos: None,
+                eos: None,
+                unk: None,
+            }
+        };
 
         let mut is_lora = false;
         let model = match self.kind {
@@ -481,7 +499,7 @@ impl Loader for GGUFLoader {
         let gen_conf: Option<GenerationConfig> = paths
             .get_gen_conf_filename()
             .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
-        let chat_template = get_chat_template(paths, &self.chat_template);
+        let mut chat_template = get_chat_template(paths, &self.chat_template);
 
         let max_seq_len = match model {
             Model::Llama(ref l) => l.max_seq_len,
@@ -502,14 +520,28 @@ impl Loader for GGUFLoader {
             Model::Phi3(ref model) => model.cache.lock().len(),
             Model::XLoraPhi3(ref model) => model.cache.lock().len(),
         };
+
+        if chat_template.bos_token.is_none() && bos.is_some() {
+            chat_template.bos_token = Some(BeginEndUnkTok(Either::Left(bos.unwrap())));
+        }
+        if chat_template.eos_token.is_none() && eos.is_some() {
+            chat_template.eos_token = Some(BeginEndUnkTok(Either::Left(eos.unwrap())));
+        }
+        if chat_template.unk_token.is_none() && unk.is_some() {
+            chat_template.unk_token = Some(BeginEndUnkTok(Either::Left(unk.unwrap())));
+        }
+
         let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
         Ok(Arc::new(Mutex::new(GGUFPipeline {
             model,
             tok_trie: tok_trie.clone(),
             tokenizer: tokenizer.into(),
             no_kv_cache: self.no_kv_cache,
             chat_template: Arc::new(chat_template),
-            model_id: self.model_id.clone(),
+            model_id: self
+                .model_id
+                .clone()
+                .unwrap_or(self.quantized_model_id.clone()),
             non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                 NonGranularState {
                     non_granular_index: Arc::new(Mutex::new(0)),
@@ -532,7 +564,7 @@ impl Loader for GGUFLoader {
     fn get_id(&self) -> String {
         self.xlora_model_id
             .as_deref()
-            .unwrap_or(&self.model_id)
+            .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
             .to_string()
     }
 
21 changes: 19 additions & 2 deletions mistralrs-core/src/pipeline/gguf_tokenizer.rs
@@ -12,7 +12,14 @@ use tracing::info;
 
 use crate::DEBUG;
 
-pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
+pub struct ConversionResult {
+    pub tokenizer: Tokenizer,
+    pub bos: Option<String>,
+    pub eos: Option<String>,
+    pub unk: Option<String>,
+}
+
+pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult> {
     let model = content.metadata["tokenizer.ggml.model"]
         .to_string()
         .expect("GGUF tokenizer model is not a string.")
@@ -67,6 +74,10 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
         .to_u32()
         .expect("GGUF unk token is not u32");
 
+    let bos_str = tokens[bos as usize].clone();
+    let eos_str = tokens[eos as usize].clone();
+    let unk_str = tokens[unk as usize].clone();
+
     let (tokenizer, ty) = match model.as_str() {
         "llama" | "replit" => {
             // unigram
@@ -112,7 +123,12 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
     if DEBUG.load(Ordering::Relaxed) {
         info!("Tokenizer: {tokenizer:?}");
     }
-    Ok(tokenizer)
+    Ok(ConversionResult {
+        tokenizer,
+        bos: Some(bos_str),
+        eos: Some(eos_str),
+        unk: Some(unk_str),
+    })
 }
 
 mod tests {
Expand Down Expand Up @@ -152,6 +168,7 @@ mod tests {
.map_err(anyhow::Error::msg)?,
)
.map_err(anyhow::Error::msg)
.map(|res| res.tokenizer)
}
other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
}
Expand Down
20 changes: 14 additions & 6 deletions mistralrs-core/src/pipeline/macros.rs
@@ -146,12 +146,14 @@ macro_rules! get_paths_gguf {
            .with_token(get_token($token_source)?)
            .build()?;
        let revision = $revision.unwrap_or("main".to_string());
+        let model_id_this = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
+        let model_id_copy = model_id_this.clone();
        let api = api.repo(Repo::with_revision(
-            $this.model_id.clone(),
+            model_id_this.clone(),
            RepoType::Model,
            revision.clone(),
        ));
-        let model_id = std::path::Path::new(&$this.model_id);
+        let model_id = std::path::Path::new(&model_id_copy);
 
        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") {
@@ -171,8 +173,8 @@ macro_rules! get_paths_gguf {
        let filenames = get_model_paths(
            revision.clone(),
            &$token_source,
-            &$quantized_model_id,
-            &$quantized_filename,
+            &Some($quantized_model_id),
+            &Some($quantized_filename),
            &api,
            &model_id,
        )?;
@@ -185,7 +187,7 @@ macro_rules! get_paths_gguf {
            xlora_config,
            lora_preload_adapter_info,
        } = get_xlora_paths(
-            $this.model_id.clone(),
+            model_id_this,
            &$this.xlora_model_id,
            &$token_source,
            revision.clone(),
@@ -205,8 +207,14 @@ macro_rules! get_paths_gguf {
            None
        };
 
+        let tokenizer_filename = if $this.model_id.is_some() {
+            $crate::api_get_file!(api, "tokenizer.json", model_id)
+        } else {
+            PathBuf::from_str("")?
+        };
+
        Ok(Box::new($path_name {
-            tokenizer_filename: PathBuf::from_str("")?,
+            tokenizer_filename,
            config_filename: PathBuf::from_str("")?,
            filenames,
            xlora_adapter_configs: adapter_configs,
12 changes: 8 additions & 4 deletions mistralrs-core/src/pipeline/mod.rs
@@ -598,9 +598,13 @@ pub trait Pipeline: Send + Sync {
        } else {
            None
        };
-        let eos_tok = match chat_template.eos_token {
-            Either::Left(ref lit) => lit,
-            Either::Right(ref added) => &added.content,
+        let eos_tok = if let Some(ref unk) = self.get_chat_template().eos_token {
+            match unk.0 {
+                Either::Left(ref lit) => Some(lit.to_string()),
+                Either::Right(ref added) => Some(added.content.to_string()),
+            }
+        } else {
+            None
        };
        let unk_tok = if let Some(ref unk) = self.get_chat_template().unk_token {
            match unk.0 {
@@ -1436,7 +1440,7 @@ mod tests {
            true,
            template,
            Some(bos.to_string()),
-            eos,
+            Some(eos.to_string()),
            Some(unk.to_string()),
        )
        .unwrap_or_else(|_| panic!("Template number {i}"));