Add an example #357

Merged: 2 commits merged on May 29, 2024
3 changes: 3 additions & 0 deletions chat_templates/llama2.json
@@ -0,0 +1,3 @@
{
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
}
3 changes: 3 additions & 0 deletions chat_templates/llama3.json
@@ -0,0 +1,3 @@
{
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
}
3 changes: 3 additions & 0 deletions chat_templates/mistral.json
@@ -0,0 +1,3 @@
{
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
}
3 changes: 3 additions & 0 deletions chat_templates/phi3.json
@@ -0,0 +1,3 @@
{
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
}
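Each of these files carries a single `chat_template` string in the Hugging Face Jinja chat-templating format, rendered against `messages`, `bos_token`, `eos_token`, and (for Llama 3) `add_generation_prompt`. As a rough standalone sketch of how such a template could be rendered with minijinja (the engine used in `mistralrs-core`), assuming `minijinja` and `serde_json` as dependencies; this is not the crate's actual `apply_chat_template_to` implementation:

```rust
use minijinja::{context, Environment, Error, ErrorKind};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load one of the templates added here, e.g. chat_templates/mistral.json.
    let raw = std::fs::read_to_string("chat_templates/mistral.json")?;
    let json: serde_json::Value = serde_json::from_str(&raw)?;
    let template = json["chat_template"]
        .as_str()
        .expect("chat_template should be a string");

    let mut env = Environment::new();
    // The templates call raise_exception(); register a stub that surfaces the message.
    env.add_function("raise_exception", |msg: String| -> Result<String, Error> {
        Err(Error::new(ErrorKind::InvalidOperation, msg))
    });
    env.add_template("chat", template)?;

    // Roles must alternate user/assistant, as the template itself enforces.
    let messages = serde_json::json!([
        { "role": "user", "content": "Hello!" },
        { "role": "assistant", "content": "Hi! How can I help?" },
        { "role": "user", "content": "Format this conversation for the model." }
    ]);

    let prompt = env.get_template("chat")?.render(context! {
        messages => messages,
        bos_token => "<s>",
        eos_token => "</s>",
        add_generation_prompt => true,
    })?;
    println!("{prompt}");
    Ok(())
}
```

Against `mistral.json` this should print an `[INST] ... [/INST]`-wrapped prompt; the other templates differ only in their special-token layout.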
4 changes: 4 additions & 0 deletions examples/README.md
@@ -0,0 +1,4 @@
# Examples
- Python: [examples here](python)
- HTTP Server: [examples here](server)
- Rust: [examples here](../mistralrs/examples/)
29 changes: 14 additions & 15 deletions mistralrs-core/src/pipeline/chat_template.rs
@@ -30,9 +30,9 @@ fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
}

#[derive(Debug, Deserialize)]
pub struct Unk(#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>);
#[derive(Debug, Deserialize)]
pub struct Bos(#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>);
pub struct BeginEndUnkTok(
#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
@@ -41,23 +41,22 @@ pub struct ChatTemplate {
add_eos_token: Option<bool>,
added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
additional_special_tokens: Option<Vec<String>>,
pub bos_token: Option<Bos>,
pub bos_token: Option<BeginEndUnkTok>,

/// Jinja format chat templating for chat completion.
/// See: https://huggingface.co/docs/transformers/chat_templating
pub chat_template: Option<String>,
clean_up_tokenization_spaces: Option<bool>,
device_map: Option<String>,
#[serde(with = "either::serde_untagged")]
pub eos_token: Either<String, AddedTokensDecoder>,
pub eos_token: Option<BeginEndUnkTok>,
legacy: Option<bool>,
model_max_length: f64,
model_max_length: Option<f64>,
pad_token: Option<String>,
sp_model_kwargs: Option<HashMap<String, String>>,
spaces_between_special_tokens: Option<bool>,
tokenizer_class: String,
tokenizer_class: Option<String>,
truncation_size: Option<String>,
pub unk_token: Option<Unk>,
pub unk_token: Option<BeginEndUnkTok>,
use_default_system_prompt: Option<bool>,
}
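The renamed `BeginEndUnkTok` keeps the untagged `Either`, so `tokenizer_config.json` may supply `bos_token`, `eos_token`, or `unk_token` either as a bare string or as an added-token object. A hypothetical check of the string case, assuming `serde_json` and the type in scope:

```rust
// Hypothetical: a bare string deserializes into the Either::Left variant.
let bos: BeginEndUnkTok = serde_json::from_value(serde_json::json!("<s>")).unwrap();
assert!(matches!(bos.0, either::Either::Left(ref s) if s == "<s>"));
// An object such as {"content": "<s>", "special": true, ...} would instead
// land in Either::Right(AddedTokensDecoder).
```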

@@ -66,10 +65,10 @@ impl ChatTemplate {
self.chat_template.is_some()
}

pub fn eos_tok(&self) -> String {
match self.eos_token {
Either::Left(ref lit) => lit.clone(),
Either::Right(ref added) => added.content.clone(),
pub fn eos_tok(&self) -> Option<String> {
match self.eos_token.as_ref()?.0 {
Either::Left(ref lit) => Some(lit.clone()),
Either::Right(ref added) => Some(added.content.clone()),
}
}

@@ -93,7 +92,7 @@ pub fn calculate_eos_tokens(
gen_conf: Option<GenerationConfig>,
tokenizer: &Tokenizer,
) -> Vec<u32> {
let mut eos_tok_ids = vec![chat_template.eos_tok()];
let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();

for alternate in SUPPORTED_ALTERNATE_EOS {
@@ -173,7 +172,7 @@ pub fn apply_chat_template_to(
add_generation_prompt: bool,
template: &str,
bos_tok: Option<String>,
eos_tok: &str,
eos_tok: Option<String>,
unk_tok: Option<String>,
) -> Result<String> {
let mut env = Environment::new();
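With `eos_token` now optional, callers can no longer assume a literal EOS string is present; a GGUF-only run, for example, fills it in later from metadata. A hedged caller-side sketch, assuming the `tokenizers` crate; the helper name is illustrative and not taken from the crate:

```rust
use tokenizers::Tokenizer;

/// Illustrative helper: turn the optional EOS string returned by the new
/// `ChatTemplate::eos_tok()` into stop-token ids, yielding an empty list
/// instead of panicking when the template carries no EOS token.
fn resolve_eos_ids(eos_tok: Option<String>, tokenizer: &Tokenizer) -> Vec<u32> {
    match eos_tok {
        Some(eos_str) => tokenizer.token_to_id(&eos_str).into_iter().collect(),
        None => Vec::new(), // e.g. GGUF: tokenizer.ggml.eos_token_id is applied later
    }
}
```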
72 changes: 52 additions & 20 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -6,12 +6,13 @@ use super::{
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
use crate::pipeline::gguf_tokenizer::convert_ggml_to_hf_tokenizer;
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig};
use crate::pipeline::gguf_tokenizer::{convert_ggml_to_hf_tokenizer, ConversionResult};
use crate::pipeline::{get_chat_template, Cache};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::sequence::Sequence;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
use crate::xlora_models::NonGranularState;
use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG};
@@ -28,6 +29,7 @@ use candle_core::quantized::{
GgmlDType,
};
use candle_core::{DType, Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use rand_isaac::Isaac64Rng;
use std::fs;
@@ -61,10 +63,10 @@ pub struct GGUFPipeline {
}

pub struct GGUFLoader {
model_id: String,
model_id: Option<String>,
config: GGUFSpecificConfig,
quantized_model_id: Option<String>,
quantized_filename: Option<String>,
quantized_model_id: String,
quantized_filename: String,
xlora_model_id: Option<String>,
xlora_order: Option<Ordering>,
no_kv_cache: bool,
@@ -189,16 +191,16 @@ impl GGUFLoaderBuilder {

pub fn build(self) -> Box<dyn Loader> {
Box::new(GGUFLoader {
model_id: self.model_id.unwrap(),
model_id: self.model_id,
config: self.config,
xlora_model_id: self.xlora_model_id,
kind: self.kind,
xlora_order: self.xlora_order,
no_kv_cache: self.no_kv_cache,
chat_template: self.chat_template,
tgt_non_granular_index: self.tgt_non_granular_index,
quantized_filename: Some(self.quantized_filename),
quantized_model_id: Some(self.quantized_model_id),
quantized_filename: self.quantized_filename,
quantized_model_id: self.quantized_model_id,
})
}
}
@@ -208,8 +210,8 @@ impl GGUFLoader {
pub fn new(
model_id: Option<String>,
config: GGUFSpecificConfig,
quantized_model_id: Option<String>,
quantized_filename: Option<String>,
quantized_model_id: String,
quantized_filename: String,
xlora_model_id: Option<String>,
kind: ModelKind,
xlora_order: Option<Ordering>,
@@ -218,13 +220,15 @@
tgt_non_granular_index: Option<usize>,
) -> Self {
let model_id = if let Some(id) = model_id {
id
} else {
Some(id)
} else if let Some(xlora_order) = xlora_order.clone() {
info!(
"Using adapter base model ID: `{}`",
xlora_order.as_ref().unwrap().base_model_id
xlora_order.base_model_id
);
xlora_order.as_ref().unwrap().base_model_id.clone()
Some(xlora_order.base_model_id.clone())
} else {
None
};
Self {
model_id,
@@ -280,8 +284,8 @@ impl Loader for GGUFLoader {
&token_source,
revision,
self,
self.quantized_model_id,
self.quantized_filename,
self.quantized_model_id.clone(),
self.quantized_filename.clone(),
silent
);
self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant)
@@ -356,7 +360,21 @@ impl Loader for GGUFLoader {
info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
}

let tokenizer = convert_ggml_to_hf_tokenizer(&model)?;
let ConversionResult {
tokenizer,
bos,
eos,
unk,
} = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
convert_ggml_to_hf_tokenizer(&model)?
} else {
ConversionResult {
tokenizer: get_tokenizer(paths.get_tokenizer_filename())?,
bos: None,
eos: None,
unk: None,
}
};
Comment on lines +363 to +377

Contributor: I really don't want to deal with resolving conflicts again; this will be annoying for me to rebase against, like the debug addition was.

Contributor: and... you just merged it 😑

Owner (author): Ah, sorry! I didn't see the comment in time. Is there anything I can do to help?

Contributor: Not really, but due to the number of commits, and my change coming right after this addition, I get conflicts rippled through the history.

I'll probably be lazy and just use a merge commit to resolve instead of a proper rebase. I raised the PR before going to bed since there has been a lot of activity recently on the files I was refactoring and it's starting to get tedious, so it's best to get what I've done so far in when you have time to review 👍

Owner (author): Ok, sounds good, a merge commit is fine. I'll make a review now, and then maybe you can do the rebase/merge?

let mut is_lora = false;
let model = match self.kind {
@@ -481,7 +499,7 @@ impl Loader for GGUFLoader {
let gen_conf: Option<GenerationConfig> = paths
.get_gen_conf_filename()
.map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
let chat_template = get_chat_template(paths, &self.chat_template);
let mut chat_template = get_chat_template(paths, &self.chat_template);

let max_seq_len = match model {
Model::Llama(ref l) => l.max_seq_len,
@@ -502,14 +520,28 @@ impl Loader for GGUFLoader {
Model::Phi3(ref model) => model.cache.lock().len(),
Model::XLoraPhi3(ref model) => model.cache.lock().len(),
};

if chat_template.bos_token.is_none() && bos.is_some() {
chat_template.bos_token = Some(BeginEndUnkTok(Either::Left(bos.unwrap())));
}
if chat_template.eos_token.is_none() && eos.is_some() {
chat_template.eos_token = Some(BeginEndUnkTok(Either::Left(eos.unwrap())));
}
if chat_template.unk_token.is_none() && unk.is_some() {
chat_template.unk_token = Some(BeginEndUnkTok(Either::Left(unk.unwrap())));
}

let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
Ok(Arc::new(Mutex::new(GGUFPipeline {
model,
tok_trie: tok_trie.clone(),
tokenizer: tokenizer.into(),
no_kv_cache: self.no_kv_cache,
chat_template: Arc::new(chat_template),
model_id: self.model_id.clone(),
model_id: self
.model_id
.clone()
.unwrap_or(self.quantized_model_id.clone()),
non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
NonGranularState {
non_granular_index: Arc::new(Mutex::new(0)),
@@ -532,7 +564,7 @@ impl Loader for GGUFLoader {
fn get_id(&self) -> String {
self.xlora_model_id
.as_deref()
.unwrap_or(&self.model_id)
.unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
.to_string()
}

21 changes: 19 additions & 2 deletions mistralrs-core/src/pipeline/gguf_tokenizer.rs
@@ -12,7 +12,14 @@ use tracing::info;

use crate::DEBUG;

pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
pub struct ConversionResult {
pub tokenizer: Tokenizer,
pub bos: Option<String>,
pub eos: Option<String>,
pub unk: Option<String>,
}

pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult> {
let model = content.metadata["tokenizer.ggml.model"]
.to_string()
.expect("GGUF tokenizer model is not a string.")
@@ -67,6 +74,10 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
.to_u32()
.expect("GGUF unk token is not u32");

let bos_str = tokens[bos as usize].clone();
let eos_str = tokens[eos as usize].clone();
let unk_str = tokens[unk as usize].clone();

let (tokenizer, ty) = match model.as_str() {
"llama" | "replit" => {
// unigram
@@ -112,7 +123,12 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
if DEBUG.load(Ordering::Relaxed) {
info!("Tokenizer: {tokenizer:?}");
}
Ok(tokenizer)
Ok(ConversionResult {
tokenizer,
bos: Some(bos_str),
eos: Some(eos_str),
unk: Some(unk_str),
})
}

mod tests {
@@ -152,6 +168,7 @@ mod tests {
.map_err(anyhow::Error::msg)?,
)
.map_err(anyhow::Error::msg)
.map(|res| res.tokenizer)
}
other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
}
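For reference, a sketch of how a caller might consume the richer return value where a bare `Tokenizer` came back before; the helper and its logging are illustrative, assume `ConversionResult` is in scope, and are not the pipeline's exact code:

```rust
use tokenizers::Tokenizer;
use tracing::info;

// Illustrative only: keep the special-token strings around so an incomplete
// chat template can be patched later (as gguf.rs now does).
fn unpack(res: ConversionResult) -> (Tokenizer, Option<String>, Option<String>, Option<String>) {
    let ConversionResult { tokenizer, bos, eos, unk } = res;
    info!(
        "GGUF tokenizer: vocab size {}, bos={:?}, eos={:?}, unk={:?}",
        tokenizer.get_vocab_size(true),
        bos,
        eos,
        unk
    );
    (tokenizer, bos, eos, unk)
}
```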
20 changes: 14 additions & 6 deletions mistralrs-core/src/pipeline/macros.rs
@@ -146,12 +146,14 @@ macro_rules! get_paths_gguf {
.with_token(get_token($token_source)?)
.build()?;
let revision = $revision.unwrap_or("main".to_string());
let model_id_this = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
let model_id_copy = model_id_this.clone();
let api = api.repo(Repo::with_revision(
$this.model_id.clone(),
model_id_this.clone(),
RepoType::Model,
revision.clone(),
));
let model_id = std::path::Path::new(&$this.model_id);
let model_id = std::path::Path::new(&model_id_copy);

let chat_template = if let Some(ref p) = $this.chat_template {
if p.ends_with(".json") {
@@ -171,8 +173,8 @@
let filenames = get_model_paths(
revision.clone(),
&$token_source,
&$quantized_model_id,
&$quantized_filename,
&Some($quantized_model_id),
&Some($quantized_filename),
&api,
&model_id,
)?;
@@ -185,7 +187,7 @@
xlora_config,
lora_preload_adapter_info,
} = get_xlora_paths(
$this.model_id.clone(),
model_id_this,
&$this.xlora_model_id,
&$token_source,
revision.clone(),
@@ -205,8 +207,14 @@
None
};

let tokenizer_filename = if $this.model_id.is_some() {
$crate::api_get_file!(api, "tokenizer.json", model_id)
} else {
PathBuf::from_str("")?
};

Ok(Box::new($path_name {
tokenizer_filename: PathBuf::from_str("")?,
tokenizer_filename,
config_filename: PathBuf::from_str("")?,
filenames,
xlora_adapter_configs: adapter_configs,
12 changes: 8 additions & 4 deletions mistralrs-core/src/pipeline/mod.rs
@@ -64,14 +64,14 @@
fn get_weight_filenames(&self) -> &[PathBuf];

/// Retrieve the PretrainedConfig file.
/// See: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig

[GitHub Actions / Docs warning on line 67: this URL is not a hyperlink]
fn get_config_filename(&self) -> &PathBuf;

/// A serialised `tokenizers.Tokenizer` HuggingFace object.
/// See: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer

[GitHub Actions / Docs warning on line 71: this URL is not a hyperlink]
fn get_tokenizer_filename(&self) -> &PathBuf;

/// Content expected to deserialize to [`ChatTemplate`].

[GitHub Actions / Docs warning on line 74: public documentation for `get_template_filename` links to private item `ChatTemplate`]
fn get_template_filename(&self) -> &PathBuf;

/// Optional adapter files. `(String, PathBuf)` is of the form `(id name, path)`.
@@ -598,9 +598,13 @@
} else {
None
};
let eos_tok = match chat_template.eos_token {
Either::Left(ref lit) => lit,
Either::Right(ref added) => &added.content,
let eos_tok = if let Some(ref unk) = self.get_chat_template().eos_token {
match unk.0 {
Either::Left(ref lit) => Some(lit.to_string()),
Either::Right(ref added) => Some(added.content.to_string()),
}
} else {
None
};
let unk_tok = if let Some(ref unk) = self.get_chat_template().unk_token {
match unk.0 {
@@ -1436,7 +1440,7 @@
true,
template,
Some(bos.to_string()),
eos,
Some(eos.to_string()),
Some(unk.to_string()),
)
.unwrap_or_else(|_| panic!("Template number {i}"));