-
Notifications
You must be signed in to change notification settings - Fork 340
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ced3cab
commit cbccb41
Showing
6 changed files
with
106 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from mistralrs import Runner, Which, ChatCompletionRequest, Architecture | ||
|
||
runner = Runner( | ||
which=Which.Plain( | ||
model_id="mistralai/Mistral-7B-Instruct-v0.1", | ||
tokenizer_json=None, | ||
repeat_last_n=64, | ||
arch=Architecture.Mistral, | ||
), | ||
) | ||
|
||
res = runner.send_chat_completion_request( | ||
ChatCompletionRequest( | ||
model="mistral", | ||
messages=[ | ||
{"role": "user", "content": "Tell me a story about the Rust type system."} | ||
], | ||
max_tokens=256, | ||
presence_penalty=1.0, | ||
top_p=0.1, | ||
temperature=0.1, | ||
) | ||
) | ||
print(res.choices[0].message.content) | ||
print(res.usage) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# mistral.rs Rust API: `mistralrs` | ||
[![Documentation](https://github.com/EricLBuehler/mistral.rs/actions/workflows/docs.yml/badge.svg)](https://ericlbuehler.github.io/mistral.rs/mistralrs/) | ||
|
||
Mistral.rs provides a convenient Rust multithreaded API. To install, add `mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" }` to the Cargo.toml file. | ||
Mistral.rs provides a convenient Rust multithreaded/async API. To install, add `mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git" }` to the Cargo.toml file. | ||
|
||
Examples can be found [here](examples). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
use either::Either; | ||
use image::{ColorType, DynamicImage}; | ||
use indexmap::IndexMap; | ||
use std::sync::Arc; | ||
use tokio::sync::mpsc::channel; | ||
|
||
use mistralrs::{ | ||
Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalRequest, Request, | ||
RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, VisionLoaderBuilder, | ||
VisionLoaderType, VisionSpecificConfig, | ||
}; | ||
|
||
fn setup() -> anyhow::Result<Arc<MistralRs>> { | ||
// Select a Mistral model | ||
let loader = VisionLoaderBuilder::new( | ||
VisionSpecificConfig { | ||
use_flash_attn: false, | ||
repeat_last_n: 64, | ||
}, | ||
None, | ||
None, | ||
Some("microsoft/Phi-3-vision-128k-instruct".to_string()), | ||
) | ||
.build(VisionLoaderType::Phi3V); | ||
// Load, into a Pipeline | ||
let pipeline = loader.load_model_from_hf( | ||
None, | ||
TokenSource::CacheToken, | ||
None, | ||
&Device::cuda_if_available(0)?, | ||
false, | ||
DeviceMapMetadata::dummy(), | ||
None, | ||
)?; | ||
// Create the MistralRs, which is a runner | ||
Ok(MistralRsBuilder::new(pipeline, SchedulerMethod::Fixed(5.try_into().unwrap())).build()) | ||
} | ||
|
||
fn main() -> anyhow::Result<()> { | ||
let mistralrs = setup()?; | ||
|
||
let (tx, mut rx) = channel(10_000); | ||
let request = Request::Normal(NormalRequest { | ||
messages: RequestMessage::VisionChat { | ||
images: vec![DynamicImage::new(1280, 720, ColorType::Rgb8)], | ||
messages: vec![IndexMap::from([ | ||
("role".to_string(), Either::Left("user".to_string())), | ||
( | ||
"content".to_string(), | ||
Either::Left("<|image_1|>\nWhat is shown in this image?".to_string()), | ||
), | ||
])], | ||
}, | ||
sampling_params: SamplingParams::default(), | ||
response: tx, | ||
return_logprobs: false, | ||
is_streaming: false, | ||
id: 0, | ||
constraint: Constraint::None, | ||
suffix: None, | ||
adapters: None, | ||
}); | ||
mistralrs.get_sender().blocking_send(request)?; | ||
|
||
let response = rx.blocking_recv().unwrap(); | ||
match response { | ||
Response::Done(c) => println!("Text: {}", c.choices[0].message.content), | ||
_ => unreachable!(), | ||
} | ||
Ok(()) | ||
} |