Support using GGUF chat template #388

Closed · wants to merge 11 commits
Cargo.toml (1 addition & 1 deletion)
@@ -32,7 +32,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
futures = "0.3"
clap = { version = "4.5.1", features = ["derive"] }
pyo3 = { version = "0.21.0", features = ["full", "extension-module"] }
pyo3 = { version = "0.21.0", features = ["full", "extension-module", "either"] }
tokio = { version = "1.36.0", features = ["full", "rt-multi-thread"] }
once_cell = "1.19.0"
image = "0.25.1"
README.md (8 additions & 0 deletions)
@@ -284,6 +284,14 @@ please consider using the method demonstrated in examples below, where the token
**Supported GGUF tokenizer types**
- `llama`

Some GGUF models are very large and are sharded into multiple files. Mistral.rs supports loading these sharded models; to do so, separate the `.gguf` filenames with spaces:

```bash
./mistralrs-server --chat-template <chat_template> gguf -m . -f "a.gguf b.gguf"
```

The Python API also accepts a list of filename strings for this case.

## Run

To start a server serving Mistral GGUF on `localhost:1234`,
mistralrs-core/src/device_map.rs (1 addition & 0 deletions)
@@ -18,6 +18,7 @@ impl DeviceMapMetadata {
host_layers: None,
}
}
/// A dummy device mapper that does not map any layers to devices.
pub fn dummy() -> Self {
Self {
device_layers: None,
mistralrs-core/src/gguf/chat_template.rs (20 additions & 0 deletions)
@@ -0,0 +1,20 @@
use tracing::info;

use super::Content;

// Get chat template from GGUF metadata if it exists.
pub fn get_gguf_chat_template<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Option<String> {
content
.get_metadata("tokenizer.chat_template")
.ok()
.map(|template| {
let template = template
.to_string()
.expect("Chat template must be a string")
.clone();
info!("Discovered and using GGUF chat template: `{template}`");
template
})
}
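
For orientation, here is a minimal crate-internal sketch (not part of the diff) of how this helper can be driven, assuming the `Content` type and the re-exports added in `mistralrs-core/src/gguf/mod.rs` below; the shard paths and the surrounding function are placeholders:

```rust
use std::fs::File;

use crate::gguf::{get_gguf_chat_template, Content};

// Hypothetical call site: open each GGUF shard, build a multi-file `Content`,
// then probe the metadata for an embedded chat template.
fn chat_template_from_gguf() -> anyhow::Result<Option<String>> {
    let mut shard_a = File::open("a.gguf")?; // placeholder paths
    let mut shard_b = File::open("b.gguf")?;
    let mut readers = [&mut shard_a, &mut shard_b];
    let content = Content::from_readers(&mut readers).map_err(anyhow::Error::msg)?;
    // `None` means the metadata has no `tokenizer.chat_template` key, in which
    // case the template has to come from elsewhere (e.g. `--chat-template`).
    Ok(get_gguf_chat_template(&content))
}
```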
mistralrs-core/src/gguf/content.rs (167 additions & 0 deletions)
@@ -0,0 +1,167 @@
use std::fs;

use anyhow::Context;
use candle_core::{
quantized::{
gguf_file::{self, Value},
QTensor,
},
Device, Result,
};
use indexmap::IndexMap;
use tracing::info;

use crate::{pipeline::GGUFArchitecture, DEBUG};

fn parse_gguf_value(value: &Value) -> String {
match value {
Value::Array(vs) => vs
.iter()
.map(parse_gguf_value)
.collect::<Vec<String>>()
.join(", "),
Value::Bool(b) => b.to_string(),
Value::F32(x) => x.to_string(),
Value::F64(x) => x.to_string(),
Value::I8(x) => x.to_string(),
Value::I16(x) => x.to_string(),
Value::I32(x) => x.to_string(),
Value::I64(x) => x.to_string(),
Value::String(x) => x.to_string(),
Value::U8(x) => x.to_string(),
Value::U16(x) => x.to_string(),
Value::U32(x) => x.to_string(),
Value::U64(x) => x.to_string(),
}
}

// Internal invariant: contents and readers must be paired.
/// Abstracts the files of a GGUF model, allowing the model to be split across multiple files.
pub struct Content<'a, R: std::io::Seek + std::io::Read> {
contents: Vec<gguf_file::Content>,
readers: &'a mut [&'a mut R],
arch: GGUFArchitecture,
}

impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> {
/// Create a `Content` from a set of file readers.
pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result<Self> {
let mut contents = Vec::new();
let n_readers = readers.len();
for reader in readers.iter_mut() {
contents.push(gguf_file::Content::read(reader)?);
}
let n_splits = contents
.iter()
.filter_map(|ct| {
ct.metadata
.get("split.count")
.map(|val| val.to_u64().unwrap())
})
.collect::<Vec<_>>();
if n_splits.len() > 1 {
candle_core::bail!("Multiple contents have multiple `split.count` fields");
}
#[allow(clippy::cast_possible_truncation)]
if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
candle_core::bail!("Number of readers does not match the number of splits.");
} else if n_splits.len() == 1 {
info!("Model n splits: {}", n_splits[0]);
}

let mut arch = None;
for ct in &contents {
if !ct.metadata.contains_key("general.architecture") {
continue;
}

arch = Some(
ct.metadata["general.architecture"]
.to_string()
.context("Model metadata should have declared an architecture")
.and_then(GGUFArchitecture::from_value)
.unwrap(),
);
}
let arch = arch.expect("GGUF files must specify `general.architecture`");
Ok(Self {
contents,
readers,
arch,
})
}

pub fn arch(&self) -> GGUFArchitecture {
self.arch
}

/// Retrieve a tensor, searching through each content.
pub fn tensor(&mut self, name: &str, device: &Device) -> Result<QTensor> {
for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) {
if let Some(tensor_info) = ct.tensor_infos.get(name) {
return tensor_info.read(reader, ct.tensor_data_offset, device);
}
}
candle_core::bail!("Cannot find tensor info for {name}")
}

/// Print metadata for these contents.
/// This will also log tensor names, shapes and dtypes to `mistralrs_gguf_tensors.txt` if DEBUG is enabled.
pub fn print_metadata(&self) -> anyhow::Result<()> {
// Collect metadata keys and values from every content (and tensor info when DEBUG is enabled)
let mut keys = Vec::new();
let mut metadatas = Vec::new();
let mut tensors = Vec::new();
for ct in &self.contents {
keys.extend(ct.metadata.keys());
metadatas.push(&ct.metadata);

if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
for (name, info) in &ct.tensor_infos {
tensors.push(format!(
"name = `{name}`, shape = {:?}, dtype = {:?}",
info.shape.clone(),
info.ggml_dtype
));
}
}
}

info!("Model config:");
keys.sort();
let mut output_keys = IndexMap::new();
for name in keys {
if !name.contains("tokenizer") {
for metadata in &metadatas {
if let Some(val) = metadata.get(name) {
output_keys.insert(name, parse_gguf_value(val));
}
}
}
}
for (name, val) in output_keys {
println!("{name}: {val}")
}

if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
fs::write(
"mistralrs_gguf_tensors.txt",
serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
)?;

info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
}

anyhow::Ok(())
}

/// Get a metadata value by key, searching each content.
pub fn get_metadata(&self, name: &str) -> Result<&Value> {
for content in &self.contents {
if let Some(v) = content.metadata.get(name) {
return Ok(v);
}
}
candle_core::bail!("Cannot find metadata for {name}")
}
}
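
As a further hypothetical sketch (placeholder shard paths; the metadata key and tensor name follow common GGUF conventions and are used purely for illustration), lookups walk the shards in order and return the first match:

```rust
use std::fs::File;

use candle_core::Device;

use crate::gguf::Content;

// Hypothetical example: `get_metadata` and `tensor` search every shard and
// return the first match; a missing key or tensor is an error, not a panic.
fn inspect_shards() -> anyhow::Result<()> {
    let mut shard_a = File::open("a.gguf")?; // placeholder paths
    let mut shard_b = File::open("b.gguf")?;
    let mut readers = [&mut shard_a, &mut shard_b];
    let mut content = Content::from_readers(&mut readers).map_err(anyhow::Error::msg)?;

    let ctx_len = content
        .get_metadata("llama.context_length")
        .map_err(anyhow::Error::msg)?
        .to_u64()
        .map_err(anyhow::Error::msg)?;
    let embedding = content
        .tensor("token_embd.weight", &Device::Cpu)
        .map_err(anyhow::Error::msg)?;
    println!("context_length = {ctx_len}, embedding shape = {:?}", embedding.shape());
    Ok(())
}
```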
mistralrs-core/src/gguf/gguf_tokenizer.rs
@@ -1,7 +1,6 @@
use std::sync::atomic::Ordering;

use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use tokenizers::{
decoders::{self, byte_fallback::ByteFallback, fuse::Fuse, strip::Strip},
models::unigram::Unigram,
@@ -12,27 +11,33 @@ use tracing::info;

use crate::DEBUG;

pub struct ConversionResult {
use super::Content;

pub struct GgufTokenizerConversion {
pub tokenizer: Tokenizer,
pub bos: Option<String>,
pub eos: Option<String>,
pub unk: Option<String>,
}

pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResult> {
let model = content.metadata["tokenizer.ggml.model"]
// Convert a GGUF tokenizer to an HF `Tokenizer`, together with its special-token strings.
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
let model = content
.get_metadata("tokenizer.ggml.model")?
.to_string()
.expect("GGUF tokenizer model is not a string.")
.clone();
let tokens = content.metadata["tokenizer.ggml.tokens"]
let tokens = content
.get_metadata("tokenizer.ggml.tokens")?
.to_vec()
.expect("GGUF tokenizer tokens is not a vec.")
.iter()
.map(|t| t.to_string().expect("GGUF token is not a string.").clone())
.collect::<Vec<_>>();
let added_tokens = content
.metadata
.get("tokenizer.ggml.added_tokens")
.get_metadata("tokenizer.ggml.added_tokens")
.map(|items| {
items
.to_vec()
@@ -45,15 +50,15 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResul
})
.collect::<Vec<_>>()
});
let scores = content.metadata.get("tokenizer.ggml.scores").map(|items| {
let scores = content.get_metadata("tokenizer.ggml.scores").map(|items| {
items
.to_vec()
.expect("GGUF tokenizer scores is not a vec.")
.iter()
.map(|t| t.to_f32().expect("GGUF score is not a f32."))
.collect::<Vec<_>>()
});
let merges = content.metadata.get("tokenizer.ggml.merges").map(|items| {
let merges = content.get_metadata("tokenizer.ggml.merges").map(|items| {
items
.to_vec()
.expect("GGUF tokenizer merges is not a vec.")
@@ -63,15 +68,16 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResul
});

let unk = content
.metadata
.get("tokenizer.ggml.unknown_token_id")
.get_metadata("tokenizer.ggml.unknown_token_id")
.map(|t| t.to_u32().expect("GGUF unk token is not u32"));

let eos = content.metadata["tokenizer.ggml.eos_token_id"]
let eos = content
.get_metadata("tokenizer.ggml.eos_token_id")?
.to_u32()
.expect("GGUF unk token is not u32");

let bos = content.metadata["tokenizer.ggml.bos_token_id"]
let bos = content
.get_metadata("tokenizer.ggml.bos_token_id")?
.to_u32()
.expect("GGUF unk token is not u32");

@@ -128,7 +134,7 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResul
if DEBUG.load(Ordering::Relaxed) {
info!("Tokenizer: {tokenizer:?}");
}
Ok(ConversionResult {
Ok(GgufTokenizerConversion {
tokenizer,
bos: Some(bos_str),
eos: Some(eos_str),
@@ -137,12 +143,12 @@ pub fn convert_ggml_to_hf_tokenizer(content: &Content) -> Result<ConversionResul
}

mod tests {
use crate::gguf::Content;
use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;

use super::convert_ggml_to_hf_tokenizer;
use super::convert_gguf_to_hf_tokenizer;

#[allow(dead_code)]
#[derive(Debug)]
@@ -167,8 +173,8 @@

let filename = api.get("mistral-7b-instruct-v0.1.Q2_K.gguf").unwrap();
let mut file = std::fs::File::open(&filename)?;
convert_ggml_to_hf_tokenizer(
&Content::read(&mut file)
convert_gguf_to_hf_tokenizer(
&Content::from_readers(&mut [&mut file])
.map_err(|e| e.with_path(filename))
.map_err(anyhow::Error::msg)?,
)
mistralrs-core/src/gguf/mod.rs (7 additions & 0 deletions)
@@ -0,0 +1,7 @@
mod chat_template;
mod content;
mod gguf_tokenizer;

pub use chat_template::get_gguf_chat_template;
pub use content::Content;
pub use gguf_tokenizer::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion};
mistralrs-core/src/layers.rs (5 additions & 6 deletions)
@@ -11,7 +11,7 @@ use std::{
};

use candle_core::{
quantized::{gguf_file, QMatMul, QTensor},
quantized::{QMatMul, QTensor},
DType, Device, IndexOp, Result, Tensor,
};
use candle_nn::{Linear, Module, VarBuilder};
@@ -20,7 +20,7 @@ use either::Either;
pub use crate::layers_masker::CausalMasker;
pub use crate::layers_utils::{flash_attn, repeat_kv, verify_sanity_gguf};

use crate::{cublaslt::CUBLASLT_HANDLE, INHIBIT_GEMM_F16};
use crate::{cublaslt::CUBLASLT_HANDLE, gguf::Content, INHIBIT_GEMM_F16};

#[derive(Debug, Clone)]
pub struct RmsNorm {
@@ -407,13 +407,12 @@ pub struct QLinear {

impl QLinear {
pub fn new<R: std::io::Read + std::io::Seek>(
ct: &gguf_file::Content,
r: &mut R,
ct: &mut Content<'_, R>,
name: &str,
device: &Device,
) -> Result<Self> {
let w = ct.tensor(r, &format!("{name}.weight"), device)?;
let b = ct.tensor(r, &format!("{name}.bias"), device)?;
let w = ct.tensor(&format!("{name}.weight"), device)?;
let b = ct.tensor(&format!("{name}.bias"), device)?;
let inner = QMatMul::from_qtensor(w)?;
let bias = b.dequantize(device)?;
Ok(Self {
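
For completeness, a hypothetical call site after this signature change (crate-internal; the `"output"` layer name is a placeholder and the `crate::layers` module path is assumed from the crate layout). The reader is no longer threaded through separately, because `Content` owns the readers and resolves which shard holds each tensor:

```rust
use candle_core::{Device, Result};

use crate::gguf::Content;
use crate::layers::QLinear;

// Hypothetical call site: `QLinear::new` reads `output.weight` and `output.bias`
// from whichever shard of the multi-file `Content` contains them.
fn load_output_layer<R: std::io::Read + std::io::Seek>(
    ct: &mut Content<'_, R>,
    device: &Device,
) -> Result<QLinear> {
    QLinear::new(ct, "output", device)
}
```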