diff --git a/chat_templates/llama2.json b/chat_templates/llama2.json
new file mode 100644
index 000000000..800a077f2
--- /dev/null
+++ b/chat_templates/llama2.json
@@ -0,0 +1,3 @@
+{
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
+}
\ No newline at end of file
diff --git a/chat_templates/llama3.json b/chat_templates/llama3.json
new file mode 100644
index 000000000..61bafeb2e
--- /dev/null
+++ b/chat_templates/llama3.json
@@ -0,0 +1,3 @@
+{
+ "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+}
\ No newline at end of file
diff --git a/chat_templates/mistral.json b/chat_templates/mistral.json
new file mode 100644
index 000000000..15544fda6
--- /dev/null
+++ b/chat_templates/mistral.json
@@ -0,0 +1,3 @@
+{
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+}
\ No newline at end of file
diff --git a/chat_templates/phi3.json b/chat_templates/phi3.json
new file mode 100644
index 000000000..6d92f29e6
--- /dev/null
+++ b/chat_templates/phi3.json
@@ -0,0 +1,3 @@
+{
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
+}
\ No newline at end of file
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 000000000..043a2211d
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,4 @@
+# Examples
+- Python: [examples here](python)
+- HTTP Server: [examples here](server)
+- Rust: [examples here](../mistralrs/examples/)
\ No newline at end of file
diff --git a/mistralrs-core/src/pipeline/chat_template.rs b/mistralrs-core/src/pipeline/chat_template.rs
index ee7dfa115..e419b8901 100644
--- a/mistralrs-core/src/pipeline/chat_template.rs
+++ b/mistralrs-core/src/pipeline/chat_template.rs
@@ -30,9 +30,9 @@ fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
}
#[derive(Debug, Deserialize)]
-pub struct Unk(#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>);
-#[derive(Debug, Deserialize)]
-pub struct Bos(#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>);
+pub struct BeginEndUnkTok(
+ #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
+);
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
@@ -41,23 +41,22 @@ pub struct ChatTemplate {
add_eos_token: Option<bool>,
added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
additional_special_tokens: Option<Vec<String>>,
- pub bos_token: Option<Bos>,
+ pub bos_token: Option<BeginEndUnkTok>,
/// Jinja format chat templating for chat completion.
/// See: https://huggingface.co/docs/transformers/chat_templating
pub chat_template: Option<String>,
clean_up_tokenization_spaces: Option<bool>,
device_map: Option,
- #[serde(with = "either::serde_untagged")]
- pub eos_token: Either<String, AddedTokensDecoder>,
+ pub eos_token: Option<BeginEndUnkTok>,
legacy: Option<bool>,
- model_max_length: f64,
+ model_max_length: Option<f64>,
pad_token: Option<String>,
sp_model_kwargs: Option<HashMap<String, String>>,
spaces_between_special_tokens: Option<bool>,
- tokenizer_class: String,
+ tokenizer_class: Option<String>,
truncation_size: Option,
- pub unk_token: Option<Unk>,
+ pub unk_token: Option<BeginEndUnkTok>,
use_default_system_prompt: Option<bool>,
}
@@ -66,10 +65,10 @@ impl ChatTemplate {
self.chat_template.is_some()
}
- pub fn eos_tok(&self) -> String {
- match self.eos_token {
- Either::Left(ref lit) => lit.clone(),
- Either::Right(ref added) => added.content.clone(),
+ pub fn eos_tok(&self) -> Option<String> {
+ match self.eos_token.as_ref()?.0 {
+ Either::Left(ref lit) => Some(lit.clone()),
+ Either::Right(ref added) => Some(added.content.clone()),
}
}
@@ -93,7 +92,7 @@ pub fn calculate_eos_tokens(
gen_conf: Option<GenerationConfig>,
tokenizer: &Tokenizer,
) -> Vec<u32> {
- let mut eos_tok_ids = vec![chat_template.eos_tok()];
+ let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
for alternate in SUPPORTED_ALTERNATE_EOS {
@@ -173,7 +172,7 @@ pub fn apply_chat_template_to(
add_generation_prompt: bool,
template: &str,
bos_tok: Option<String>,
- eos_tok: &str,
+ eos_tok: Option<String>,
unk_tok: Option<String>,
) -> Result<String> {
let mut env = Environment::new();
diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs
index ae3bb9dca..71520b1d6 100644
--- a/mistralrs-core/src/pipeline/gguf.rs
+++ b/mistralrs-core/src/pipeline/gguf.rs
@@ -6,12 +6,13 @@ use super::{
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
use crate::lora::Ordering;
-use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
-use crate::pipeline::gguf_tokenizer::convert_ggml_to_hf_tokenizer;
+use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig};
+use crate::pipeline::gguf_tokenizer::{convert_ggml_to_hf_tokenizer, ConversionResult};
use crate::pipeline::{get_chat_template, Cache};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::sequence::Sequence;
+use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
use crate::xlora_models::NonGranularState;
use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, DEBUG};
@@ -28,6 +29,7 @@ use candle_core::quantized::{
GgmlDType,
};
use candle_core::{DType, Device, Tensor};
+use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use rand_isaac::Isaac64Rng;
use std::fs;
@@ -61,10 +63,10 @@ pub struct GGUFPipeline {
}
pub struct GGUFLoader {
- model_id: String,
+ model_id: Option<String>,
config: GGUFSpecificConfig,
- quantized_model_id: Option<String>,
- quantized_filename: Option<String>,
+ quantized_model_id: String,
+ quantized_filename: String,
xlora_model_id: Option<String>,
xlora_order: Option<Ordering>,
no_kv_cache: bool,
@@ -189,7 +191,7 @@ impl GGUFLoaderBuilder {
pub fn build(self) -> Box<dyn Loader> {
Box::new(GGUFLoader {
- model_id: self.model_id.unwrap(),
+ model_id: self.model_id,
config: self.config,
xlora_model_id: self.xlora_model_id,
kind: self.kind,
@@ -197,8 +199,8 @@ impl GGUFLoaderBuilder {
no_kv_cache: self.no_kv_cache,
chat_template: self.chat_template,
tgt_non_granular_index: self.tgt_non_granular_index,
- quantized_filename: Some(self.quantized_filename),
- quantized_model_id: Some(self.quantized_model_id),
+ quantized_filename: self.quantized_filename,
+ quantized_model_id: self.quantized_model_id,
})
}
}
@@ -208,8 +210,8 @@ impl GGUFLoader {
pub fn new(
model_id: Option<String>,
config: GGUFSpecificConfig,
- quantized_model_id: Option<String>,
- quantized_filename: Option<String>,
+ quantized_model_id: String,
+ quantized_filename: String,
xlora_model_id: Option<String>,
kind: ModelKind,
xlora_order: Option<Ordering>,
@@ -218,13 +220,15 @@ impl GGUFLoader {
tgt_non_granular_index: Option<usize>,
) -> Self {
let model_id = if let Some(id) = model_id {
- id
- } else {
+ Some(id)
+ } else if let Some(xlora_order) = xlora_order.clone() {
info!(
"Using adapter base model ID: `{}`",
- xlora_order.as_ref().unwrap().base_model_id
+ xlora_order.base_model_id
);
- xlora_order.as_ref().unwrap().base_model_id.clone()
+ Some(xlora_order.base_model_id.clone())
+ } else {
+ None
};
Self {
model_id,
@@ -280,8 +284,8 @@ impl Loader for GGUFLoader {
&token_source,
revision,
self,
- self.quantized_model_id,
- self.quantized_filename,
+ self.quantized_model_id.clone(),
+ self.quantized_filename.clone(),
silent
);
self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant)
@@ -356,7 +360,21 @@ impl Loader for GGUFLoader {
info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
}
- let tokenizer = convert_ggml_to_hf_tokenizer(&model)?;
+ let ConversionResult {
+ tokenizer,
+ bos,
+ eos,
+ unk,
+ } = if paths.get_tokenizer_filename().to_string_lossy().is_empty() {
+ convert_ggml_to_hf_tokenizer(&model)?
+ } else {
+ ConversionResult {
+ tokenizer: get_tokenizer(paths.get_tokenizer_filename())?,
+ bos: None,
+ eos: None,
+ unk: None,
+ }
+ };
let mut is_lora = false;
let model = match self.kind {
@@ -481,7 +499,7 @@ impl Loader for GGUFLoader {
let gen_conf: Option<GenerationConfig> = paths
.get_gen_conf_filename()
.map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
- let chat_template = get_chat_template(paths, &self.chat_template);
+ let mut chat_template = get_chat_template(paths, &self.chat_template);
let max_seq_len = match model {
Model::Llama(ref l) => l.max_seq_len,
@@ -502,6 +520,17 @@ impl Loader for GGUFLoader {
Model::Phi3(ref model) => model.cache.lock().len(),
Model::XLoraPhi3(ref model) => model.cache.lock().len(),
};
+
+ if chat_template.bos_token.is_none() && bos.is_some() {
+ chat_template.bos_token = Some(BeginEndUnkTok(Either::Left(bos.unwrap())));
+ }
+ if chat_template.eos_token.is_none() && eos.is_some() {
+ chat_template.eos_token = Some(BeginEndUnkTok(Either::Left(eos.unwrap())));
+ }
+ if chat_template.unk_token.is_none() && unk.is_some() {
+ chat_template.unk_token = Some(BeginEndUnkTok(Either::Left(unk.unwrap())));
+ }
+
let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
Ok(Arc::new(Mutex::new(GGUFPipeline {
model,
@@ -509,7 +538,10 @@ impl Loader for GGUFLoader {
tokenizer: tokenizer.into(),
no_kv_cache: self.no_kv_cache,
chat_template: Arc::new(chat_template),
- model_id: self.model_id.clone(),
+ model_id: self
+ .model_id
+ .clone()
+ .unwrap_or(self.quantized_model_id.clone()),
non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
NonGranularState {
non_granular_index: Arc::new(Mutex::new(0)),
@@ -532,7 +564,7 @@ impl Loader for GGUFLoader {
fn get_id(&self) -> String {
self.xlora_model_id
.as_deref()
- .unwrap_or(&self.model_id)
+ .unwrap_or(self.model_id.as_ref().unwrap_or(&self.quantized_model_id))
.to_string()
}
diff --git a/mistralrs-core/src/pipeline/gguf_tokenizer.rs b/mistralrs-core/src/pipeline/gguf_tokenizer.rs
index 1a8333616..1d6985c1f 100644
--- a/mistralrs-core/src/pipeline/gguf_tokenizer.rs
+++ b/mistralrs-core/src/pipeline/gguf_tokenizer.rs
@@ -12,7 +12,14 @@ use tracing::info;
use crate::DEBUG;
-pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<Tokenizer> {
+pub struct ConversionResult {
+ pub tokenizer: Tokenizer,
+ pub bos: Option<String>,
+ pub eos: Option<String>,
+ pub unk: Option<String>,
+}
+
+pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult> {
let model = content.metadata["tokenizer.ggml.model"]
.to_string()
.expect("GGUF tokenizer model is not a string.")
@@ -67,6 +74,10 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result {
.to_u32()
.expect("GGUF unk token is not u32");
+ let bos_str = tokens[bos as usize].clone();
+ let eos_str = tokens[eos as usize].clone();
+ let unk_str = tokens[unk as usize].clone();
+
let (tokenizer, ty) = match model.as_str() {
"llama" | "replit" => {
// unigram
@@ -112,7 +123,12 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result {
if DEBUG.load(Ordering::Relaxed) {
info!("Tokenizer: {tokenizer:?}");
}
- Ok(tokenizer)
+ Ok(ConversionResult {
+ tokenizer,
+ bos: Some(bos_str),
+ eos: Some(eos_str),
+ unk: Some(unk_str),
+ })
}
mod tests {
@@ -152,6 +168,7 @@ mod tests {
.map_err(anyhow::Error::msg)?,
)
.map_err(anyhow::Error::msg)
+ .map(|res| res.tokenizer)
}
other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
}
diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs
index 7f8f663d5..6e29c940c 100644
--- a/mistralrs-core/src/pipeline/macros.rs
+++ b/mistralrs-core/src/pipeline/macros.rs
@@ -146,12 +146,14 @@ macro_rules! get_paths_gguf {
.with_token(get_token($token_source)?)
.build()?;
let revision = $revision.unwrap_or("main".to_string());
+ let model_id_this = $this.model_id.clone().unwrap_or($this.quantized_model_id.clone());
+ let model_id_copy = model_id_this.clone();
let api = api.repo(Repo::with_revision(
- $this.model_id.clone(),
+ model_id_this.clone(),
RepoType::Model,
revision.clone(),
));
- let model_id = std::path::Path::new(&$this.model_id);
+ let model_id = std::path::Path::new(&model_id_copy);
let chat_template = if let Some(ref p) = $this.chat_template {
if p.ends_with(".json") {
@@ -171,8 +173,8 @@ macro_rules! get_paths_gguf {
let filenames = get_model_paths(
revision.clone(),
&$token_source,
- &$quantized_model_id,
- &$quantized_filename,
+ &Some($quantized_model_id),
+ &Some($quantized_filename),
&api,
&model_id,
)?;
@@ -185,7 +187,7 @@ macro_rules! get_paths_gguf {
xlora_config,
lora_preload_adapter_info,
} = get_xlora_paths(
- $this.model_id.clone(),
+ model_id_this,
&$this.xlora_model_id,
&$token_source,
revision.clone(),
@@ -205,8 +207,14 @@ macro_rules! get_paths_gguf {
None
};
+ let tokenizer_filename = if $this.model_id.is_some() {
+ $crate::api_get_file!(api, "tokenizer.json", model_id)
+ } else {
+ PathBuf::from_str("")?
+ };
+
Ok(Box::new($path_name {
- tokenizer_filename: PathBuf::from_str("")?,
+ tokenizer_filename,
config_filename: PathBuf::from_str("")?,
filenames,
xlora_adapter_configs: adapter_configs,
diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs
index 9d7dbee83..5dae166a3 100644
--- a/mistralrs-core/src/pipeline/mod.rs
+++ b/mistralrs-core/src/pipeline/mod.rs
@@ -598,9 +598,13 @@ pub trait Pipeline: Send + Sync {
} else {
None
};
- let eos_tok = match chat_template.eos_token {
- Either::Left(ref lit) => lit,
- Either::Right(ref added) => &added.content,
+ let eos_tok = if let Some(ref unk) = self.get_chat_template().eos_token {
+ match unk.0 {
+ Either::Left(ref lit) => Some(lit.to_string()),
+ Either::Right(ref added) => Some(added.content.to_string()),
+ }
+ } else {
+ None
};
let unk_tok = if let Some(ref unk) = self.get_chat_template().unk_token {
match unk.0 {
@@ -1436,7 +1440,7 @@ mod tests {
true,
template,
Some(bos.to_string()),
- eos,
+ Some(eos.to_string()),
Some(unk.to_string()),
)
.unwrap_or_else(|_| panic!("Template number {i}"));
diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml
index d105982bd..49027c7e0 100644
--- a/mistralrs/Cargo.toml
+++ b/mistralrs/Cargo.toml
@@ -52,4 +52,8 @@ required-features = []
[[example]]
name = "lora_activation"
+required-features = []
+
+[[example]]
+name = "gguf_locally"
required-features = []
\ No newline at end of file
diff --git a/mistralrs/examples/gguf_locally/main.rs b/mistralrs/examples/gguf_locally/main.rs
new file mode 100644
index 000000000..b04fc9fa5
--- /dev/null
+++ b/mistralrs/examples/gguf_locally/main.rs
@@ -0,0 +1,64 @@
+use std::sync::Arc;
+use tokio::sync::mpsc::channel;
+
+use mistralrs::{
+ Constraint, Device, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs,
+ MistralRsBuilder, NormalRequest, Request, RequestMessage, Response, SamplingParams,
+ SchedulerMethod, TokenSource,
+};
+
+fn setup() -> anyhow::Result<Arc<MistralRs>> {
+ // Select a Mistral model
+ // We do not use any files from HF servers here, and instead load the
+ // chat template from the specified file, and the tokenizer and model from a
+ // local GGUF file at the path `.`
+ let loader = GGUFLoaderBuilder::new(
+ GGUFSpecificConfig { repeat_last_n: 64 },
+ Some("chat_templates/mistral.json".to_string()),
+ None,
+ ".".to_string(),
+ "mistral-7b-instruct-v0.1.Q4_K_M.gguf".to_string(),
+ )
+ .build();
+ // Load the model into a Pipeline
+ let pipeline = loader.load_model_from_hf(
+ None,
+ TokenSource::CacheToken,
+ None,
+ &Device::cuda_if_available(0)?,
+ false,
+ DeviceMapMetadata::dummy(),
+ None,
+ )?;
+ // Create the MistralRs, which is a runner
+ Ok(MistralRsBuilder::new(pipeline, SchedulerMethod::Fixed(5.try_into().unwrap())).build())
+}
+
+fn main() -> anyhow::Result<()> {
+ let mistralrs = setup()?;
+
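+ // Create a channel over which the runner will send back the response.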
+ let (tx, mut rx) = channel(10_000);
+ let request = Request::Normal(NormalRequest {
+ messages: RequestMessage::Completion {
+ text: "Hello! My name is ".to_string(),
+ echo_prompt: false,
+ best_of: 1,
+ },
+ sampling_params: SamplingParams::default(),
+ response: tx,
+ return_logprobs: false,
+ is_streaming: false,
+ id: 0,
+ constraint: Constraint::None,
+ suffix: None,
+ adapters: None,
+ });
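+ // Send the request to the runner; this example is synchronous, so the blocking channel API is used.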
+ mistralrs.get_sender().blocking_send(request)?;
+
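+ // Wait for the single, non-streamed completion response.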
+ let response = rx.blocking_recv().unwrap();
+ match response {
+ Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
+ _ => unreachable!(),
+ }
+ Ok(())
+}
diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs
index 58f1ac92b..b6539edaf 100644
--- a/mistralrs/examples/quantized/main.rs
+++ b/mistralrs/examples/quantized/main.rs
@@ -9,6 +9,7 @@ use mistralrs::{
fn setup() -> anyhow::Result<Arc<MistralRs>> {
// Select a Mistral model
+ // This uses a model, tokenizer, and chat template from the HF Hub.
let loader = GGUFLoaderBuilder::new(
GGUFSpecificConfig { repeat_last_n: 64 },
None,