,
config: GGUFSpecificConfig,
diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs
index 4792dc0e85..28efe62917 100644
--- a/mistralrs-core/src/pipeline/macros.rs
+++ b/mistralrs-core/src/pipeline/macros.rs
@@ -1,3 +1,4 @@
+#[doc(hidden)]
#[macro_export]
macro_rules! api_dir_list {
($api:expr, $model_id:expr) => {
@@ -44,6 +45,7 @@ macro_rules! api_dir_list {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! api_get_file {
($api:expr, $file:expr, $model_id:expr) => {
@@ -73,6 +75,7 @@ macro_rules! api_get_file {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! get_paths {
($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filename:expr, $silent:expr) => {{
@@ -186,6 +189,7 @@ macro_rules! get_paths {
}};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! get_paths_gguf {
($path_name:ident, $token_source:expr, $revision:expr, $this:expr, $quantized_model_id:expr, $quantized_filename:expr, $silent:expr) => {{
@@ -309,6 +313,7 @@ macro_rules! get_paths_gguf {
}};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
@@ -333,6 +338,7 @@ macro_rules! normal_model_loader {
}};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
@@ -348,13 +354,16 @@ macro_rules! vision_normal_model_loader {
&$config,
$use_flash_attn,
vb,
- $mapper,
- $loading_isq,
- $real_device,
+ $crate::pipeline::NormalLoadingMetadata {
+ mapper: $mapper,
+ loading_isq: $loading_isq,
+ real_device: $real_device,
+ },
)?
}};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
@@ -399,6 +408,7 @@ macro_rules! xlora_model_loader {
}};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs
index 93a6201c75..18c2299f63 100644
--- a/mistralrs-core/src/pipeline/mod.rs
+++ b/mistralrs-core/src/pipeline/mod.rs
@@ -40,7 +40,7 @@ use std::{collections::HashMap, path::PathBuf, str::FromStr};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
-pub use vision_loaders::{VisionLoaderType, VisionModelLoader};
+pub use vision_loaders::{Phi3VLoader, VisionLoaderType, VisionModelLoader};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
@@ -103,6 +103,7 @@ pub trait ModelPaths {
}
#[derive(Clone)]
+/// All local paths and metadata necessary to load a model.
pub struct LocalModelPaths {
tokenizer_filename: P,
config_filename: P,
diff --git a/mistralrs-core/src/pipeline/normal_loaders.rs b/mistralrs-core/src/pipeline/normal_loaders.rs
index ce49be5c21..31eb23a677 100644
--- a/mistralrs-core/src/pipeline/normal_loaders.rs
+++ b/mistralrs-core/src/pipeline/normal_loaders.rs
@@ -223,6 +223,9 @@ impl GemmaBasicConfig {
}
}
+/// [`NormalLoader`] for a Gemma model.
+///
+/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct GemmaLoader;
impl NormalModelLoader for GemmaLoader {
@@ -313,6 +316,9 @@ impl LlamaBasicConfig {
}
}
+/// [`NormalLoader`] for a Llama model.
+///
+/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct LlamaLoader;
impl NormalModelLoader for LlamaLoader {
@@ -495,6 +501,9 @@ impl Phi2BasicConfig {
}
}
+/// [`NormalLoader`] for a Phi 2 model.
+///
+/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi2Loader;
impl NormalModelLoader for Phi2Loader {
@@ -593,6 +602,9 @@ impl Phi3BasicConfig {
}
}
+/// [`NormalLoader`] for a Phi 3 model.
+///
+/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Phi3Loader;
impl NormalModelLoader for Phi3Loader {
@@ -686,6 +698,9 @@ impl Qwen2BasicConfig {
}
}
+/// [`NormalLoader`] for a Qwen 2 model.
+///
+/// [`NormalLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.NormalLoader.html
pub struct Qwen2Loader;
impl NormalModelLoader for Qwen2Loader {
diff --git a/mistralrs-core/src/pipeline/sampling_pipeline.rs b/mistralrs-core/src/pipeline/sampling_pipeline.rs
index d569defae6..b628a99489 100644
--- a/mistralrs-core/src/pipeline/sampling_pipeline.rs
+++ b/mistralrs-core/src/pipeline/sampling_pipeline.rs
@@ -1,3 +1,4 @@
+#[doc(hidden)]
#[macro_export]
macro_rules! finish_and_add_tokens_to_seq {
($this:expr, $prefix_cacher:expr, $seq:expr, $logprobs:expr, $eos_tok:expr, $use_prefix_cacher:expr) => {{
@@ -177,6 +178,7 @@ macro_rules! finish_and_add_tokens_to_seq {
}
/// Sample and add to the prefix cache.
+#[doc(hidden)]
#[macro_export]
macro_rules! do_sample {
($this:expr, $seqs:expr, $logits:expr, $prefix_cacher:expr, $disable_eos_stop:expr, $rng:expr) => {{
diff --git a/mistralrs-core/src/pipeline/speculative.rs b/mistralrs-core/src/pipeline/speculative.rs
index 373bc060f4..18fbaf0a01 100644
--- a/mistralrs-core/src/pipeline/speculative.rs
+++ b/mistralrs-core/src/pipeline/speculative.rs
@@ -26,6 +26,7 @@ use super::{
IsqPipelineMixin, MetadataMixin, ModelCategory, ModelPaths, PreProcessingMixin,
};
+/// A loader for a speculative pipeline using 2 [`Loader`]s.
pub struct SpeculativeLoader {
pub target: Box,
pub draft: Box,
@@ -138,6 +139,7 @@ pub struct SpeculativePipeline {
}
#[derive(Copy, Clone)]
+/// Metadata for a speculative pipeline.
pub struct SpeculativeConfig {
/// γ completions to run of the draft model
pub gamma: usize,
diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs
index d009e7a433..d586a377b5 100644
--- a/mistralrs-core/src/pipeline/vision.rs
+++ b/mistralrs-core/src/pipeline/vision.rs
@@ -1,5 +1,5 @@
use super::cache_manager::DefaultCacheManager;
-use super::vision_loaders::{Phi3Loader, VisionLoaderType};
+use super::vision_loaders::{Phi3VLoader, VisionLoaderType};
use super::{
get_model_paths, get_xlora_paths, AdapterActivationMixin, Cache, CacheManager,
CacheManagerMixin, GeneralMetadata, IsqPipelineMixin, Loader, MetadataMixin, ModelCategory,
@@ -95,7 +95,7 @@ impl VisionLoaderBuilder {
setup_logger_and_debug();
let loader: Box = match loader {
- VisionLoaderType::Phi3V => Box::new(Phi3Loader),
+ VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
};
Box::new(VisionLoader {
inner: loader,
diff --git a/mistralrs-core/src/pipeline/vision_loaders.rs b/mistralrs-core/src/pipeline/vision_loaders.rs
index 191fcf47d6..de5fb3b438 100644
--- a/mistralrs-core/src/pipeline/vision_loaders.rs
+++ b/mistralrs-core/src/pipeline/vision_loaders.rs
@@ -2,7 +2,6 @@ use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};
use anyhow::Result;
-use candle_core::Device;
use candle_nn::VarBuilder;
#[cfg(feature = "pyo3_macros")]
@@ -10,12 +9,11 @@ use pyo3::pyclass;
use serde::Deserialize;
-use super::{Processor, ProcessorCreator, VisionModel};
+use super::{NormalLoadingMetadata, Processor, ProcessorCreator, VisionModel};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
-use crate::DeviceMapMetadata;
pub trait VisionModelLoader {
fn load(
@@ -23,9 +21,7 @@ pub trait VisionModelLoader {
config: &str,
use_flash_attn: bool,
vb: VarBuilder,
- mapper: DeviceMapMetadata,
- loading_isq: bool,
- device: Device,
+ normal_loading_metadata: NormalLoadingMetadata,
) -> Result>;
fn is_gptx(&self) -> bool;
fn get_config_repr(&self, config: &str, use_flash_attn: bool) -> Result>;
@@ -56,17 +52,18 @@ impl FromStr for VisionLoaderType {
// ======================== Phi 3 loader
-pub struct Phi3Loader;
+/// [`VisionLoader`] for a Phi 3 Vision model.
+///
+/// [`VisionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.VisionLoader.html
+pub struct Phi3VLoader;
-impl VisionModelLoader for Phi3Loader {
+impl VisionModelLoader for Phi3VLoader {
fn load(
&self,
config: &str,
use_flash_attn: bool,
vb: VarBuilder,
- mapper: DeviceMapMetadata,
- loading_isq: bool,
- device: Device,
+ normal_loading_metadata: NormalLoadingMetadata,
) -> Result> {
let mut config: Phi3Config = serde_json::from_str(config)?;
config.use_flash_attn = use_flash_attn;
@@ -74,9 +71,7 @@ impl VisionModelLoader for Phi3Loader {
&config,
vb,
self.is_gptx(),
- mapper,
- loading_isq,
- device,
+ normal_loading_metadata,
)?))
}
fn is_gptx(&self) -> bool {
diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs
index 9892fa1808..15cea90571 100644
--- a/mistralrs-core/src/request.rs
+++ b/mistralrs-core/src/request.rs
@@ -33,6 +33,7 @@ pub enum RequestMessage {
}
#[derive(Clone)]
+/// A normal request to the `MistralRs`.
pub struct NormalRequest {
pub messages: RequestMessage,
pub sampling_params: SamplingParams,
diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs
index fc4f4c93de..e0cad0a792 100644
--- a/mistralrs-core/src/response.rs
+++ b/mistralrs-core/src/response.rs
@@ -23,6 +23,7 @@ macro_rules! generate_repr {
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Chat completion response message.
pub struct ResponseMessage {
pub content: String,
pub role: String,
@@ -33,6 +34,7 @@ generate_repr!(ResponseMessage);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Delta in content for streaming response.
pub struct Delta {
pub content: String,
pub role: String,
@@ -43,6 +45,7 @@ generate_repr!(Delta);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// A logprob with the top logprobs for this token.
pub struct ResponseLogprob {
pub token: String,
pub logprob: f32,
@@ -55,6 +58,7 @@ generate_repr!(ResponseLogprob);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Logprobs per token.
pub struct Logprobs {
pub content: Option>,
}
@@ -64,6 +68,7 @@ generate_repr!(Logprobs);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Chat completion choice.
pub struct Choice {
pub finish_reason: String,
pub index: usize,
@@ -76,6 +81,7 @@ generate_repr!(Choice);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Completion streaming chunk choice.
pub struct ChunkChoice {
pub finish_reason: Option,
pub index: usize,
@@ -122,6 +128,7 @@ generate_repr!(ChatCompletionResponse);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Chat completion streaming response chunk.
pub struct ChatCompletionChunkResponse {
pub id: String,
pub choices: Vec,
@@ -136,6 +143,7 @@ generate_repr!(ChatCompletionChunkResponse);
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize)]
+/// Completion response choice.
pub struct CompletionChoice {
pub finish_reason: String,
pub index: usize,
diff --git a/mistralrs-core/src/sampler.rs b/mistralrs-core/src/sampler.rs
index a8da56c100..34ad8bae95 100644
--- a/mistralrs-core/src/sampler.rs
+++ b/mistralrs-core/src/sampler.rs
@@ -70,7 +70,7 @@ pub struct Sampler {
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-// Top-n logprobs element
+/// Top-n logprobs element
pub struct TopLogprob {
pub token: u32,
pub logprob: f32,
diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs
index 314f2492ef..314007140a 100644
--- a/mistralrs-core/src/utils/mod.rs
+++ b/mistralrs-core/src/utils/mod.rs
@@ -6,6 +6,7 @@ pub(crate) mod tokenizer;
pub(crate) mod tokens;
pub(crate) mod varbuilder_utils;
+#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_arcmutex {
($thing:expr) => {
@@ -17,6 +18,7 @@ macro_rules! get_mut_arcmutex {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error {
($fallible:expr, $response:expr) => {
@@ -34,6 +36,7 @@ macro_rules! handle_seq_error {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_ok {
($fallible:expr, $response:expr) => {
@@ -51,6 +54,7 @@ macro_rules! handle_seq_error_ok {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! handle_seq_error_stateaware_ok {
($fallible:expr, $seq:expr) => {
@@ -70,6 +74,7 @@ macro_rules! handle_seq_error_stateaware_ok {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
@@ -177,6 +182,7 @@ macro_rules! handle_pipeline_forward_error {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! get_mut_group {
($this:expr) => {
@@ -188,6 +194,7 @@ macro_rules! get_mut_group {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! get_bias_if_not_allowed {
($tok_trie:expr, $rx:expr, $next_token_id:expr) => {
@@ -201,6 +208,7 @@ macro_rules! get_bias_if_not_allowed {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! sample_async {
(
@@ -235,6 +243,7 @@ macro_rules! sample_async {
};
}
+#[doc(hidden)]
#[macro_export]
macro_rules! serde_default_fn {
($t:ty, $name:ident, $v:expr) => {
diff --git a/mistralrs-core/src/vision_models/phi3.rs b/mistralrs-core/src/vision_models/phi3.rs
index 970804e33b..3ce4127514 100644
--- a/mistralrs-core/src/vision_models/phi3.rs
+++ b/mistralrs-core/src/vision_models/phi3.rs
@@ -16,10 +16,11 @@ use crate::{
repeat_kv, CausalMasker, FusedBiasLinear, MatMul, Nonzero, PhiRopeConfig,
PhiRotaryEmbedding, RmsNorm, ScaledDotProductAttention,
},
- pipeline::{extract_logits, Cache, IsqModel, Phi3RopeScaling, VisionModel},
+ pipeline::{
+ extract_logits, Cache, IsqModel, NormalLoadingMetadata, Phi3RopeScaling, VisionModel,
+ },
serde_default_fn,
vision_models::clip::{Activation, ClipConfig, ClipVisionTransformer},
- DeviceMapMetadata,
};
#[derive(Debug, Clone, serde::Deserialize)]
@@ -699,9 +700,8 @@ impl ImageEmbedding {
// hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
let p_0 = positions.i((idx, 0))?.to_scalar::()? as usize;
let p_1 = positions.i((idx, 1))?.to_scalar::()? as usize;
- // TODO(EricLBuehler): https://github.com/huggingface/candle/pull/2223 will make this nicer
hidden_states = hidden_states.slice_assign(
- &[p_0..p_0 + 1, p_1..p_1 + cnt, 0..img_set_tensor.dims()[2]],
+ &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
&img_set_tensor,
)?;
idx += cnt;
@@ -720,9 +720,8 @@ impl ImageEmbedding {
let p_0 = positions.i((idx, 0))?.to_scalar::()? as usize;
let p_1 = positions.i((idx, 1))?.to_scalar::()? as usize;
// hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = ...
- // TODO(EricLBuehler): https://github.com/huggingface/candle/pull/2223 will make this nicer
hidden_states = hidden_states.slice_assign(
- &[p_0..p_0 + 1, p_1..p_1 + cnt, 0..img_set_tensor.dims()[2]],
+ &[&p_0, &(p_1..p_1 + cnt), &(..img_set_tensor.dims()[2])],
&img_set_tensor,
)?;
idx += cnt;
@@ -757,12 +756,12 @@ impl Model {
cfg: &Config,
vb: VarBuilder,
_is_gptx: bool,
- mapper: DeviceMapMetadata,
- loading_isq: bool,
- real_device: Device,
+ normal_loading_metadata: NormalLoadingMetadata,
) -> Result {
let vb_m = vb.pp("model");
- let mapper = mapper.into_mapper(cfg.num_hidden_layers, &real_device)?;
+ let mapper = normal_loading_metadata
+ .mapper
+ .into_mapper(cfg.num_hidden_layers, &normal_loading_metadata.real_device)?;
let embed_tokens = candle_nn::embedding(
cfg.vocab_size,
cfg.hidden_size,
@@ -772,7 +771,7 @@ impl Model {
cfg,
embed_tokens.clone(),
&cfg.embd_layer,
- vb_m.pp("vision_embed_tokens"),
+ mapper.set_nm_device(vb_m.pp("vision_embed_tokens"), false),
)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
@@ -780,7 +779,9 @@ impl Model {
let rotary_emb = Arc::new(PhiRotaryEmbedding::new(
vb.dtype(),
cfg.clone(),
- mapper.device_for(layer_idx, false).unwrap_or(&real_device),
+ mapper
+ .device_for(layer_idx, false)
+ .unwrap_or(&normal_loading_metadata.real_device),
)?);
let layer = DecoderLayer::new(
rotary_emb.clone(),
@@ -788,7 +789,7 @@ impl Model {
vb_l.pp(layer_idx),
&*mapper,
layer_idx,
- loading_isq,
+ normal_loading_metadata.loading_isq,
)?;
layers.push(layer)
}
@@ -800,14 +801,14 @@ impl Model {
let lm_head = linear_no_bias(
cfg.hidden_size,
cfg.vocab_size,
- mapper.set_nm_device(vb.pp("lm_head"), loading_isq),
+ mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
)?;
Ok(Self {
vision_embed_tokens,
layers,
norm,
lm_head: QMatMul::Tensor(lm_head.weight().clone()),
- device: real_device,
+ device: normal_loading_metadata.real_device,
cache: Cache::new(cfg.num_hidden_layers, false),
max_seq_len: cfg.max_position_embeddings,
mapper,
diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs
index 767040d243..b9ddb59a33 100644
--- a/mistralrs-core/src/xlora_models/quantized_phi3.rs
+++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs
@@ -141,7 +141,7 @@ impl LayerWeights {
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
- .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+ .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
@@ -323,8 +323,8 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights {
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
- cos: cos.clone(),
- sin: sin.clone(),
+ cos: cos.to_device(device)?,
+ sin: sin.to_device(device)?,
sliding_window: context_window,
})
}
diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml
index dffc347310..e1e0171386 100644
--- a/mistralrs-pyo3/Cargo.toml
+++ b/mistralrs-pyo3/Cargo.toml
@@ -17,7 +17,7 @@ doc = false
[dependencies]
pyo3.workspace = true
-mistralrs-core = { version = "0.1.15", path = "../mistralrs-core", features = ["pyo3_macros"] }
+mistralrs-core = { version = "0.1.16", path = "../mistralrs-core", features = ["pyo3_macros"] }
serde.workspace = true
serde_json.workspace = true
candle-core.workspace = true
diff --git a/mistralrs-pyo3/Cargo_template.toml b/mistralrs-pyo3/Cargo_template.toml
index 683239db7d..3d599c0bc7 100644
--- a/mistralrs-pyo3/Cargo_template.toml
+++ b/mistralrs-pyo3/Cargo_template.toml
@@ -17,7 +17,7 @@ doc = false
[dependencies]
pyo3.workspace = true
-mistralrs-core = { version = "0.1.15", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
+mistralrs-core = { version = "0.1.16", path = "../mistralrs-core", features=["pyo3_macros","$feature_name"] }
serde.workspace = true
serde_json.workspace = true
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.5.0", features=["$feature_name"] }
diff --git a/mistralrs-pyo3/pyproject.toml b/mistralrs-pyo3/pyproject.toml
index 8b8339cfcf..361cb0783e 100644
--- a/mistralrs-pyo3/pyproject.toml
+++ b/mistralrs-pyo3/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "maturin"
[project]
name = "mistralrs"
-version = "0.1.15"
+version = "0.1.16"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
diff --git a/mistralrs-pyo3/pyproject_template.toml b/mistralrs-pyo3/pyproject_template.toml
index 2ec480ef4f..229fdaa9b8 100644
--- a/mistralrs-pyo3/pyproject_template.toml
+++ b/mistralrs-pyo3/pyproject_template.toml
@@ -4,7 +4,7 @@ build-backend = "maturin"
[project]
name = "$name"
-version = "0.1.15"
+version = "0.1.16"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
diff --git a/mistralrs-server/Cargo.toml b/mistralrs-server/Cargo.toml
index 7b8f06f275..60c720edb0 100644
--- a/mistralrs-server/Cargo.toml
+++ b/mistralrs-server/Cargo.toml
@@ -22,8 +22,7 @@ axum = { version = "0.7.4", features = ["tokio"] }
tower-http = { version = "0.5.1", features = ["cors"]}
utoipa = { version = "4.2", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "7.1.0", features = ["axum"]}
-mistralrs-core = { version = "0.1.15", path = "../mistralrs-core" }
-dyn-fmt = "0.4.0"
+mistralrs-core = { version = "0.1.16", path = "../mistralrs-core" }
indexmap.workspace = true
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }
diff --git a/mistralrs-server/README.md b/mistralrs-server/README.md
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/mistralrs-vision/README.md b/mistralrs-vision/README.md
new file mode 100644
index 0000000000..ca482bb1a2
--- /dev/null
+++ b/mistralrs-vision/README.md
@@ -0,0 +1,5 @@
+# `mistralrs-vision`
+
+This crate provides vision utilities for mistral.rs inspired by torchvision.
+
+Documentation: https://ericlbuehler.github.io/mistral.rs/mistralrs_vision/index.html
\ No newline at end of file
diff --git a/mistralrs-vision/src/lib.rs b/mistralrs-vision/src/lib.rs
index 76dbf8b690..09073512dd 100644
--- a/mistralrs-vision/src/lib.rs
+++ b/mistralrs-vision/src/lib.rs
@@ -1,9 +1,32 @@
+//! This crate provides vision utilities for mistral.rs inspired by torchvision.
+//! In particular, it represents transformations on some `Self` type which are applied
+//! sequentially.
+//!
+//! ## Example
+//! ```rust
+//! use candle_core::Device;
+//! use image::{ColorType, DynamicImage};
+//! use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
+//!
+//! let image = DynamicImage::new(3, 4, ColorType::Rgb8);
+//! let transforms = Transforms {
+//! input: &ToTensor,
+//! inner_transforms: &[&Normalize {
+//! mean: vec![0.5, 0.5, 0.5],
+//! std: vec![0.5, 0.5, 0.5],
+//! }],
+//! };
+//! let transformed = image.apply(transforms, &Device::Cpu).unwrap();
+//! assert_eq!(transformed.dims(), &[3, 4, 3]);
+//! ```
+
use candle_core::{Device, Result, Tensor};
use image::DynamicImage;
mod transforms;
pub(crate) mod utils;
pub use transforms::{InterpolateResize, Normalize, ToTensor};
+/// A transform over an image. The input may vary but the output is always a Tensor.
pub trait ImageTransform {
type Input;
type Output;
@@ -11,11 +34,14 @@ pub trait ImageTransform {
fn map(&self, x: &Self::Input, device: &Device) -> Result;
}
+/// Transforms to apply, starting with the `input` and then with each transform in
+/// `inner_transforms` applied sequentially.
pub struct Transforms<'a> {
pub input: &'a dyn ImageTransform,
pub inner_transforms: &'a [&'a dyn ImageTransform],
}
+/// Application of transforms to the Self type.
pub trait ApplyTransforms<'a> {
fn apply(&self, transforms: Transforms<'a>, device: &Device) -> Result;
}
diff --git a/mistralrs-vision/src/transforms.rs b/mistralrs-vision/src/transforms.rs
index f6c462a5a1..e75838be3a 100644
--- a/mistralrs-vision/src/transforms.rs
+++ b/mistralrs-vision/src/transforms.rs
@@ -72,8 +72,7 @@ impl ImageTransform for Normalize {
}
}
-/// Do what `ToTensor` does, but also resize the image without preserving
-/// aspect ratio.
+/// Resize the image via nearest interpolation.
pub struct InterpolateResize {
pub target_w: usize,
pub target_h: usize,
diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml
index 1e0206fd74..82cd14d68b 100644
--- a/mistralrs/Cargo.toml
+++ b/mistralrs/Cargo.toml
@@ -12,7 +12,7 @@ license.workspace = true
homepage.workspace = true
[dependencies]
-mistralrs-core = { version = "0.1.15", path = "../mistralrs-core" }
+mistralrs-core = { version = "0.1.16", path = "../mistralrs-core" }
anyhow.workspace = true
tokio.workspace = true
candle-core.workspace = true
diff --git a/mistralrs/src/lib.rs b/mistralrs/src/lib.rs
index 65e6f60055..26e1b708fa 100644
--- a/mistralrs/src/lib.rs
+++ b/mistralrs/src/lib.rs
@@ -1,2 +1,50 @@
+//! This crate provides an asynchronous, multithreaded API to `mistral.rs`.
+//!
+//! ## Example
+//! ```no_run
+//! use std::sync::Arc;
+//! use tokio::sync::mpsc::channel;
+//!
+//! use mistralrs::{
+//! Constraint, DeviceMapMetadata, MistralRs, MistralRsBuilder,
+//! NormalLoaderType, NormalRequest, Request, RequestMessage, Response,
+//! SamplingParams, SchedulerMethod, TokenSource,
+//! };
+//!
+//! fn setup() -> anyhow::Result<Arc<MistralRs>> {
+//! // See the examples for how to load your model.
+//! todo!()
+//! }
+//!
+//! fn main() -> anyhow::Result<()> {
+//! let mistralrs = setup()?;
+//!
+//! let (tx, mut rx) = channel(10_000);
+//! let request = Request::Normal(NormalRequest {
+//! messages: RequestMessage::Completion {
+//! text: "Hello! My name is ".to_string(),
+//! echo_prompt: false,
+//! best_of: 1,
+//! },
+//! sampling_params: SamplingParams::default(),
+//! response: tx,
+//! return_logprobs: false,
+//! is_streaming: false,
+//! id: 0,
+//! constraint: Constraint::None,
+//! suffix: None,
+//! adapters: None,
+//! });
+//! mistralrs.get_sender().blocking_send(request)?;
+//!
+//! let response = rx.blocking_recv().unwrap();
+//! match response {
+//! Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
+//! _ => unreachable!(),
+//! }
+//! Ok(())
+//! }
+//! ```
+
pub use candle_core::{quantized::GgmlDType, DType, Device, Result};
pub use mistralrs_core::*;