chore(openai): add all openai models to Model enum #136

Merged · 2 commits · May 22, 2023 · showing changes from 1 commit
27 changes: 27 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion crates/llm-chain-openai/Cargo.toml
@@ -18,8 +18,10 @@ async-openai = "0.10.3"
async-trait = "0.1.68"
llm-chain = { path = "../llm-chain", version = "0.11.1", default-features = false }
serde = { version = "1.0.163" }
-tiktoken-rs = { version = "0.4.2", features = ["async-openai"] }
+strum = "0.24"
+strum_macros = "0.24"
+thiserror = "1.0.40"
+tiktoken-rs = { version = "0.4.2", features = ["async-openai"] }
tokio = "1.28.0"

[dev-dependencies]
138 changes: 110 additions & 28 deletions crates/llm-chain-openai/src/chatgpt/options.rs
@@ -1,53 +1,88 @@
use llm_chain::traits;
use serde::{Deserialize, Serialize};
use strum_macros::EnumString;

-/// The `Model` enum represents the available ChatGPT models that you can use through the OpenAI API. These models have different capabilities and performance characteristics, allowing you to choose the one that best suits your needs.
+/// The `Model` enum represents the available ChatGPT models that you can use through the OpenAI
+/// API.
///
-/// Currently, the available models are:
-/// - `ChatGPT3_5Turbo`: A high-performance and versatile model that offers a great balance of speed, quality, and affordability.
-/// - `GPT4`: A high-performance model that offers the best quality, but is slower and more expensive than the `ChatGPT3_5Turbo` model.
-/// - `Other(String)`: A variant that allows you to specify a custom model name as a string, in case new models are introduced or you have access to specialized models.
+/// These models have different capabilities and performance characteristics, allowing you to choose
+/// the one that best suits your needs. See <https://platform.openai.com/docs/models> for more
+/// information.
///
/// # Example
///
/// ```
/// use llm_chain_openai::chatgpt::Model;
///
-/// let turbo_model = Model::ChatGPT3_5Turbo;
+/// let turbo_model = Model::Gpt35Turbo;
/// let custom_model = Model::Other("your_custom_model_name".to_string());
/// ```
///
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, PartialEq, Eq)]
+#[non_exhaustive]
pub enum Model {
-    ChatGPT3_5Turbo,
-    GPT4,
+    /// A high-performance and versatile model that offers a great balance of speed, quality, and
+    /// affordability.
+    #[default]
+    #[strum(
+        serialize = "gpt-3.5-turbo",
+        serialize = "gpt-35-turbo",
+        serialize = "gpt3.5",
+        serialize = "gpt35"
+    )]
+    Gpt35Turbo,
+
+    /// Snapshot of gpt-3.5-turbo from March 1st 2023. Unlike gpt-3.5-turbo, this model will not
+    /// receive updates, and will be deprecated 3 months after a new version is released.
+    #[strum(serialize = "gpt-3.5-turbo-0301")]
+    Gpt35Turbo0301,
+
+    /// A high-performance model that offers the best quality, but is slower and more expensive
+    /// than the `Gpt35Turbo` model.
+    #[strum(serialize = "gpt-4", serialize = "gpt4")]
+    Gpt4,
+
+    /// Snapshot of gpt-4 from March 14th 2023. Unlike gpt-4, this model will not receive updates,
+    /// and will be deprecated 3 months after a new version is released.
+    #[strum(serialize = "gpt-4-0314")]
+    Gpt4_0314,
+
+    /// Same capabilities as the base gpt-4 model but with 4x the context length. Will be updated
+    /// with our latest model iteration.
+    #[strum(serialize = "gpt-4-32k")]
+    Gpt4_32k,
+
+    /// Snapshot of gpt-4-32k from March 14th 2023. Unlike gpt-4-32k, this model will not receive
+    /// updates, and will be deprecated 3 months after a new version is released.
+    #[strum(serialize = "gpt-4-32k-0314")]
+    Gpt4_32k0314,
+
+    /// A variant that allows you to specify a custom model name as a string, in case new models
+    /// are introduced or you have access to specialized models.
+    #[strum(default)]
    Other(String),
}
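The `EnumString` derive generates a `FromStr` impl for the enum, so every string listed in a `#[strum(serialize = ...)]` attribute parses to the same variant, and `#[strum(default)]` routes any unrecognized name into `Other`. A minimal sketch of the resulting behavior (variant and alias names taken from the diff above):

```rust
use std::str::FromStr;

fn parse_examples() {
    // Every declared alias parses to the same variant.
    assert_eq!(Model::from_str("gpt-3.5-turbo").unwrap(), Model::Gpt35Turbo);
    assert_eq!(Model::from_str("gpt35").unwrap(), Model::Gpt35Turbo);

    // Unknown names fall through to the #[strum(default)] variant.
    assert_eq!(
        Model::from_str("my-custom-model").unwrap(),
        Model::Other("my-custom-model".to_string())
    );
}
```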

-impl Default for Model {
-    fn default() -> Self {
-        Self::ChatGPT3_5Turbo
-    }
+impl Model {
+    /// included for backwards compatibility
+    #[deprecated(note = "Use `Model::Gpt35Turbo` instead")]
+    #[allow(non_upper_case_globals)]
+    pub const ChatGPT3_5Turbo: Model = Model::Gpt35Turbo;
+    /// included for backwards compatibility
+    #[deprecated(note = "Use `Model::Gpt4` instead")]
+    pub const GPT4: Model = Model::Gpt4;
}
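The deprecated associated consts keep old call sites compiling while steering users toward the new variant names. A hedged sketch of what a downstream caller sees (the function is illustrative, not part of the PR):

```rust
// Old code still compiles; without the `allow`, rustc emits a deprecation
// warning suggesting Model::Gpt35Turbo. The const is just an alias.
#[allow(deprecated)]
fn legacy_default() -> Model {
    Model::ChatGPT3_5Turbo // evaluates to Model::Gpt35Turbo
}
```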

/// The `Model` enum implements the `ToString` trait, allowing you to easily convert it to a string.
impl ToString for Model {
    fn to_string(&self) -> String {
        match &self {
-            Self::ChatGPT3_5Turbo => "gpt-3.5-turbo".to_string(),
-            Self::GPT4 => "gpt-4".to_string(),
-            Self::Other(model) => model.to_string(),
-        }
-    }
-}
-
-/// The `Model` enum implements the `From<String>` trait, allowing you to easily convert a string to a `Model`.
-impl From<String> for Model {
-    fn from(s: String) -> Self {
-        match s.as_str() {
-            "gpt-3.5-turbo" => Self::ChatGPT3_5Turbo,
-            "gpt-4" => Self::GPT4,
-            _ => Self::Other(s),
+            Model::Gpt35Turbo => "gpt-3.5-turbo".to_string(),
+            Model::Gpt4 => "gpt-4".to_string(),
+            Model::Gpt35Turbo0301 => "gpt-3.5-turbo-0301".to_string(),
+            Model::Gpt4_0314 => "gpt-4-0314".to_string(),
+            Model::Gpt4_32k => "gpt-4-32k".to_string(),
+            Model::Gpt4_32k0314 => "gpt-4-32k-0314".to_string(),
+            Model::Other(model) => model.to_string(),
        }
    }
}
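Implementing `ToString` directly works, but implementing `Display` is the more idiomatic route, since the standard library's blanket impl then provides `to_string` for free. An alternative sketch (not what this PR does), covering the same variants:

```rust
use std::fmt;

impl fmt::Display for Model {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its API model name; Other carries its own string.
        let name = match self {
            Model::Gpt35Turbo => "gpt-3.5-turbo",
            Model::Gpt35Turbo0301 => "gpt-3.5-turbo-0301",
            Model::Gpt4 => "gpt-4",
            Model::Gpt4_0314 => "gpt-4-0314",
            Model::Gpt4_32k => "gpt-4-32k",
            Model::Gpt4_32k0314 => "gpt-4-32k-0314",
            Model::Other(model) => model.as_str(),
        };
        write!(f, "{}", name)
    }
}
```

Note this would replace the manual `ToString` impl rather than supplement it, since the blanket impl would otherwise conflict; strum also ships a `Display` derive that reuses the `serialize` strings, though its alias-selection rules would need checking.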
@@ -80,3 +115,50 @@ pub struct PerExecutor {
}

impl traits::Options for PerExecutor {}

+#[cfg(test)]
+mod tests {
+    use std::str::FromStr;
+
+    use super::*;
+
+    // Tests for FromStr
+    #[test]
+    fn test_from_str() -> Result<(), Box<dyn std::error::Error>> {
+        assert_eq!(Model::from_str("gpt-3.5-turbo")?, Model::Gpt35Turbo);
+        assert_eq!(
+            Model::from_str("gpt-3.5-turbo-0301")?,
+            Model::Gpt35Turbo0301
+        );
+        assert_eq!(Model::from_str("gpt-4")?, Model::Gpt4);
+        assert_eq!(Model::from_str("gpt-4-0314")?, Model::Gpt4_0314);
+        assert_eq!(Model::from_str("gpt-4-32k")?, Model::Gpt4_32k);
+        assert_eq!(Model::from_str("gpt-4-32k-0314")?, Model::Gpt4_32k0314);
+        assert_eq!(
+            Model::from_str("custom_model")?,
+            Model::Other("custom_model".to_string())
+        );
+        Ok(())
+    }
+
+    // Test ToString
+    #[test]
+    fn test_to_string() {
+        assert_eq!(Model::Gpt35Turbo.to_string(), "gpt-3.5-turbo");
+        assert_eq!(Model::Gpt4.to_string(), "gpt-4");
+        assert_eq!(Model::Gpt35Turbo0301.to_string(), "gpt-3.5-turbo-0301");
+        assert_eq!(Model::Gpt4_0314.to_string(), "gpt-4-0314");
+        assert_eq!(Model::Gpt4_32k.to_string(), "gpt-4-32k");
+        assert_eq!(Model::Gpt4_32k0314.to_string(), "gpt-4-32k-0314");
+        assert_eq!(
+            Model::Other("custom_model".to_string()).to_string(),
+            "custom_model"
+        );
+    }
+
+    #[test]
+    fn test_to_string_deprecated() {
+        assert_eq!(Model::ChatGPT3_5Turbo.to_string(), "gpt-3.5-turbo");
+        assert_eq!(Model::GPT4.to_string(), "gpt-4");
+    }
+}
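One behavior these tests do not pin down: the `Serialize`/`Deserialize` derives use the Rust variant names, not the strum aliases. A hypothetical check (assuming `serde_json` were added as a dev-dependency):

```rust
#[test]
fn serde_uses_variant_names() {
    // serde's default externally tagged representation serializes unit
    // variants to their Rust names: "Gpt35Turbo", not "gpt-3.5-turbo".
    let json = serde_json::to_string(&Model::Gpt35Turbo).unwrap();
    assert_eq!(json, "\"Gpt35Turbo\"");

    // Round-tripping through serde therefore differs from FromStr/ToString.
    let back: Model = serde_json::from_str(&json).unwrap();
    assert_eq!(back, Model::Gpt35Turbo);
}
```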
6 changes: 3 additions & 3 deletions crates/llm-chain-openai/src/chatgpt/text_splitter.rs
@@ -49,7 +49,7 @@ mod tests {
        let chunk_overlap = 0;

        let splitter = OpenAITextSplitter {
-            model: crate::chatgpt::Model::ChatGPT3_5Turbo,
+            model: crate::chatgpt::Model::Gpt35Turbo,
        };

        let chunks = splitter.split_text(doc, max_tokens_per_chunk, chunk_overlap)?;
@@ -74,7 +74,7 @@ mod tests {
        let chunk_overlap = 1;

        let splitter = OpenAITextSplitter {
-            model: crate::chatgpt::Model::ChatGPT3_5Turbo,
+            model: crate::chatgpt::Model::Gpt35Turbo,
        };

        let chunks = splitter.split_text(doc, max_tokens_per_chunk, chunk_overlap)?;
@@ -100,7 +100,7 @@ mod tests {
        let chunk_overlap = max_tokens_per_chunk;

        let splitter = OpenAITextSplitter {
-            model: crate::chatgpt::Model::ChatGPT3_5Turbo,
+            model: crate::chatgpt::Model::Gpt35Turbo,
        };

        let chunks = splitter.split_text(doc, max_tokens_per_chunk, chunk_overlap)?;