From d61f5f20ad198b968936fdfab8f386c37810fdb0 Mon Sep 17 00:00:00 2001
From: hbc
Date: Sun, 19 Mar 2023 15:15:20 -0700
Subject: [PATCH 1/2] feat: record inference stats

---
 llama-cli/src/main.rs |  4 +++-
 llama-rs/src/lib.rs   | 45 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs
index cb3e6cc5..a8bfdd23 100644
--- a/llama-cli/src/main.rs
+++ b/llama-cli/src/main.rs
@@ -169,7 +169,9 @@ fn main() {
     println!();
 
     match res {
-        Ok(_) => (),
+        Ok(stats) => {
+            println!("{}", stats);
+        }
         Err(llama_rs::InferenceError::ContextFull) => {
             log::warn!("Context window full, stopping inference.")
         }
diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index f0d58689..2c24571d 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -5,6 +5,7 @@ use std::{
     fmt::Display,
     io::{BufRead, Read, Seek, SeekFrom},
     path::{Path, PathBuf},
+    time,
 };
 
 use thiserror::Error;
@@ -107,6 +108,38 @@ impl Default for InferenceParameters {
     }
 }
 
+pub struct InferenceStats {
+    pub feed_prompt_duration: std::time::Duration,
+    pub prompt_tokens: usize,
+    pub predict_duration: std::time::Duration,
+    pub predict_tokens: usize,
+}
+
+impl Default for InferenceStats {
+    fn default() -> Self {
+        Self {
+            feed_prompt_duration: std::time::Duration::from_secs(0),
+            prompt_tokens: 0,
+            predict_duration: std::time::Duration::from_secs(0),
+            predict_tokens: 0,
+        }
+    }
+}
+
+impl Display for InferenceStats {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "feed_prompt_duration: {}ms, prompt_tokens: {}, predict_duration: {}ms, predict_tokens: {}, per_token_duration: {:.3}ms",
+            self.feed_prompt_duration.as_millis(),
+            self.prompt_tokens,
+            self.predict_duration.as_millis(),
+            self.predict_tokens,
+            (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64),
+        )
+    }
+}
+
 type TokenId = i32;
 type Token = String;
 
@@ -1236,10 +1269,16 @@ impl InferenceSession {
         maximum_token_count: Option<usize>,
         rng: &mut impl rand::Rng,
         callback: impl Fn(OutputToken) -> Result<(), E>,
-    ) -> Result<(), InferenceError> {
+    ) -> Result<InferenceStats, InferenceError> {
+        let mut stats = InferenceStats::default();
+
+        let start_at = time::SystemTime::now();
+
         // Feed the initial prompt through the transformer, to update its
         // context window with new data.
         self.feed_prompt(model, vocab, params, prompt, |tk| callback(tk))?;
+        stats.feed_prompt_duration = start_at.elapsed().unwrap();
+        stats.prompt_tokens = self.n_past;
 
         // After the prompt is consumed, sample tokens by repeatedly calling
         // `infer_next_token`. We generate tokens until the model returns an
@@ -1261,8 +1300,10 @@ impl InferenceSession {
                 break;
             }
         }
+        stats.predict_duration = start_at.elapsed().unwrap();
+        stats.predict_tokens = self.n_past;
 
-        Ok(())
+        Ok(stats)
     }
 
     /// Obtains a serializable snapshot of the current inference status.
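
A minimal caller-side sketch (not part of the patch; the field values are invented purely for illustration) of what PATCH 1/2 exposes: `InferenceStats` has public fields and a `Display` impl, so it can be constructed and printed directly.

    // Assumes a binary depending on llama-rs with PATCH 1/2 applied; the numbers are made up.
    use std::time::Duration;

    fn main() {
        let stats = llama_rs::InferenceStats {
            feed_prompt_duration: Duration::from_millis(420),
            prompt_tokens: 12,
            predict_duration: Duration::from_millis(3600),
            predict_tokens: 48,
        };
        // With PATCH 1/2 the Display impl renders a single comma-separated line:
        // feed_prompt_duration: 420ms, prompt_tokens: 12, predict_duration: 3600ms, predict_tokens: 48, per_token_duration: 75.000ms
        println!("{}", stats);
    }
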
From 3ff2bf7fdd214e098bb613a34b981ad4cea7ab19 Mon Sep 17 00:00:00 2001
From: hbc
Date: Mon, 20 Mar 2023 13:35:46 -0700
Subject: [PATCH 2/2] chore: split by \n

---
 llama-rs/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs
index 2c24571d..4c1473ff 100644
--- a/llama-rs/src/lib.rs
+++ b/llama-rs/src/lib.rs
@@ -130,7 +130,7 @@ impl Display for InferenceStats {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         write!(
             f,
-            "feed_prompt_duration: {}ms, prompt_tokens: {}, predict_duration: {}ms, predict_tokens: {}, per_token_duration: {:.3}ms",
+            "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms",
             self.feed_prompt_duration.as_millis(),
             self.prompt_tokens,
             self.predict_duration.as_millis(),
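
A hedged sanity-check sketch, not part of either patch: after PATCH 2/2 the same value renders one statistic per line, which the test below (again with invented numbers) asserts.

    // Assumes a test in a crate depending on llama-rs with both patches applied.
    #[test]
    fn display_splits_stats_by_newline() {
        let stats = llama_rs::InferenceStats {
            feed_prompt_duration: std::time::Duration::from_millis(420),
            prompt_tokens: 12,
            predict_duration: std::time::Duration::from_millis(3600),
            predict_tokens: 48,
        };
        let rendered = format!("{}", stats);
        // Five fields, each on its own line, starting with the prompt-feeding time.
        assert_eq!(rendered.lines().count(), 5);
        assert!(rendered.starts_with("feed_prompt_duration: 420ms\n"));
    }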