Skip to content

Commit

Permalink
Add some examples
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jun 7, 2024
1 parent ced3cab commit cbccb41
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 12 deletions.
25 changes: 25 additions & 0 deletions examples/python/plain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from mistralrs import Runner, Which, ChatCompletionRequest, Architecture

# Build a Runner hosting a "plain" (unquantized Hugging Face) model:
# Mistral-7B-Instruct-v0.1, loaded with the Mistral architecture.
runner = Runner(
    which=Which.Plain(
        model_id="mistralai/Mistral-7B-Instruct-v0.1",
        tokenizer_json=None,  # None — presumably falls back to the model's bundled tokenizer; confirm
        repeat_last_n=64,  # context window considered for the repeat penalty
        arch=Architecture.Mistral,
    ),
)

# Send one non-streaming chat completion request and block for the result.
# Low top_p/temperature make the sampling close to greedy.
res = runner.send_chat_completion_request(
    ChatCompletionRequest(
        model="mistral",
        messages=[
            {"role": "user", "content": "Tell me a story about the Rust type system."}
        ],
        max_tokens=256,
        presence_penalty=1.0,
        top_p=0.1,
        temperature=0.1,
    )
)
# Print the generated text of the first choice, then token-usage stats.
print(res.choices[0].message.content)
print(res.usage)
2 changes: 1 addition & 1 deletion mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub use pipeline::{
MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Qwen2Loader, SpeculativeConfig,
SpeculativeLoader, SpeculativePipeline, TokenSource, VisionLoader, VisionLoaderBuilder,
VisionModelLoader, VisionSpecificConfig,
VisionLoaderType, VisionModelLoader, VisionSpecificConfig,
};
pub use request::{Constraint, Content, NormalRequest, Request, RequestMessage};
pub use response::Response;
Expand Down
9 changes: 0 additions & 9 deletions mistralrs-core/src/vision_models/phi3_inputs_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,15 +379,6 @@ impl ImagePreProcessor for Phi3InputsProcessor {

let hd_image = Self::hd_transform(image, config.num_crops.expect("Need `num_crops`"));

let transforms_hd2 = Transforms {
input: &ToTensor,
inner_transforms: &[],
};

// (3,h,w)
let hd_image2 = hd_image.apply(transforms_hd2, device)?;
dbg!(hd_image2);

// Both hd and global have a normalization
// Transforms for the HD image
let transforms_hd = Transforms {
Expand Down
9 changes: 8 additions & 1 deletion mistralrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ anyhow.workspace = true
tokio.workspace = true
candle-core.workspace = true
serde_json.workspace = true
image.workspace = true
indexmap.workspace = true
either.workspace = true

[features]
cuda = ["mistralrs-core/cuda"]
Expand Down Expand Up @@ -56,4 +59,8 @@ required-features = []

[[example]]
name = "gguf_locally"
required-features = []
required-features = []

[[example]]
name = "phi3v"
required-features = []
2 changes: 1 addition & 1 deletion mistralrs/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# mistral.rs Rust API: `mistralrs`
[![Documentation](https://github.com/EricLBuehler/mistral.rs/actions/workflows/docs.yml/badge.svg)](https://ericlbuehler.github.io/mistral.rs/mistralrs/)

Mistral.rs provides a convenient Rust multithreaded API. To install, add `mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" }` to the Cargo.toml file.
Mistral.rs provides a convenient Rust multithreaded/async API. To install, add `mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" }` to the Cargo.toml file.

Examples can be found [here](examples).
71 changes: 71 additions & 0 deletions mistralrs/examples/phi3v/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use either::Either;
use image::{ColorType, DynamicImage};
use indexmap::IndexMap;
use std::sync::Arc;
use tokio::sync::mpsc::channel;

use mistralrs::{
Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalRequest, Request,
RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, VisionLoaderBuilder,
VisionLoaderType, VisionSpecificConfig,
};

/// Build a `MistralRs` runner hosting the Phi-3-vision model.
///
/// Loads `microsoft/Phi-3-vision-128k-instruct` from the Hugging Face Hub,
/// reading the HF token from the local cache (`TokenSource::CacheToken`),
/// and places the model on CUDA device 0 when available (CPU otherwise).
fn setup() -> anyhow::Result<Arc<MistralRs>> {
    // Configure a vision loader for the Phi-3V architecture.
    // NOTE(review): the original comment said "Select a Mistral model",
    // but the model id below is Phi-3-vision — fixed here.
    let loader = VisionLoaderBuilder::new(
        VisionSpecificConfig {
            use_flash_attn: false,
            repeat_last_n: 64,
        },
        None,
        None,
        Some("microsoft/Phi-3-vision-128k-instruct".to_string()),
    )
    .build(VisionLoaderType::Phi3V);
    // Load, into a Pipeline
    let pipeline = loader.load_model_from_hf(
        None,
        TokenSource::CacheToken,
        None,
        &Device::cuda_if_available(0)?,
        false, // not silent: progress is reported during load
        DeviceMapMetadata::dummy(), // no cross-device layer mapping
        None,
    )?;
    // Create the MistralRs, which is a runner
    Ok(MistralRsBuilder::new(pipeline, SchedulerMethod::Fixed(5.try_into().unwrap())).build())
}

/// Send one vision chat request to the runner and print the completion.
fn main() -> anyhow::Result<()> {
    let runner = setup()?;

    // Channel over which the engine delivers its response.
    let (sender, mut receiver) = channel(10_000);

    // One user message referencing the attached image via the
    // `<|image_1|>` placeholder expected by Phi-3V prompts.
    let mut message = IndexMap::new();
    message.insert("role".to_string(), Either::Left("user".to_string()));
    message.insert(
        "content".to_string(),
        Either::Left("<|image_1|>\nWhat is shown in this image?".to_string()),
    );

    // A single non-streaming request carrying one blank 1280x720 RGB image.
    let request = Request::Normal(NormalRequest {
        messages: RequestMessage::VisionChat {
            images: vec![DynamicImage::new(1280, 720, ColorType::Rgb8)],
            messages: vec![message],
        },
        sampling_params: SamplingParams::default(),
        response: sender,
        return_logprobs: false,
        is_streaming: false,
        id: 0,
        constraint: Constraint::None,
        suffix: None,
        adapters: None,
    });
    runner.get_sender().blocking_send(request)?;

    // Block until the engine replies; a finished completion is the only
    // outcome this example handles.
    let response = receiver.blocking_recv().unwrap();
    if let Response::Done(completion) = response {
        println!("Text: {}", completion.choices[0].message.content);
    } else {
        unreachable!()
    }
    Ok(())
}

0 comments on commit cbccb41

Please sign in to comment.