refactor(openai): improve model name handling and test organization
- Extract model name and reasoning effort handling into dedicated logic (see the sketch below)
- Split large test functions into smaller atomic units
- Add TestModelConfig struct and assert_request helper for test reuse
- Fix temperature type handling (f32 vs f64)
- Improve code organization and readability with clear sections
- Add comprehensive test coverage for all model name scenarios

Signed-off-by: da-moon <contact@havi.dev>
da-moon committed Feb 3, 2025
1 parent a0cdd40 commit fc43a49
Showing 1 changed file with 239 additions and 56 deletions.
295 changes: 239 additions & 56 deletions crates/goose/src/providers/formats/openai.rs
@@ -256,15 +256,32 @@ pub fn create_request(
tools: &[Tool],
image_format: &ImageFormat,
) -> anyhow::Result<Value, Error> {
if model_config.model_name.starts_with("o1-mini") {
let is_o1 = model_config.model_name.starts_with("o1");
let is_o3 = model_config.model_name.starts_with("o3");

// Only extract reasoning effort for O1/O3 models
let (model_name, reasoning_effort) = if is_o1 || is_o3 {
let parts: Vec<&str> = model_config.model_name.split('-').collect();
let last_part = parts.last().unwrap();

match *last_part {
"low" | "medium" | "high" => {
let base_name = parts[..parts.len()-1].join("-");
(base_name, Some(last_part.to_string()))
},
_ => (model_config.model_name.to_string(), Some("medium".to_string()))
}
} else {
// For non-O family models, use the model name as is and no reasoning effort
(model_config.model_name.to_string(), None)
};

if model_name.starts_with("o1-mini") {
return Err(anyhow!(
"o1-mini model is not currently supported since Goose uses tool calling and o1-mini does not support it. Please use o1 or o3 models instead."
));
}

let is_o1 = model_config.model_name.starts_with("o1");
let is_o3 = model_config.model_name.starts_with("o3");

let system_message = json!({
// NOTE: per OpenAI docs, with o1 and newer models, `developer`
// should replace the `system` role.
@@ -284,16 +301,26 @@ pub fn create_request(
messages_array.extend(messages_spec);

let mut payload = json!({
"model": model_config.model_name,
"model": model_name,
"messages": messages_array
});

// NOTE: add reasoning effort if present,
// e.g. if the user chooses `o3-mini-high` as their model name,
// then `reasoning_effort` is set to `high`.
// Defaults to medium per OpenAI docs:
// https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
if let Some(effort) = reasoning_effort {
payload.as_object_mut().unwrap()
.insert("reasoning_effort".to_string(), json!(effort));
}

// Add tools if present
if !tools_spec.is_empty() {
payload
.as_object_mut()
.unwrap()
payload.as_object_mut().unwrap()
.insert("tools".to_string(), json!(tools_spec));
}

// o1, o3 models currently don't support temperature
if !is_o1 && !is_o3 {
if let Some(temp) = model_config.temperature {
@@ -316,26 +343,7 @@
.unwrap()
.insert(key.to_string(), json!(tokens));
}
// NOTE: add reasoning effort if present,
// e.g. if the user chooses `o3-mini-high` as their model name,
// then it will set `reasoning_effort` to `high`.
// Defaults to medium per OpenAI docs:
// https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
if is_o1 || is_o3 {
let mut reasoning_effort = "medium";
// Extract the last part of model name using '-' as delimiter
if let Some(last_part) = model_config.model_name.split('-').last() {
// Check if it's a valid reasoning effort value
match last_part {
"low" | "medium" | "high" => reasoning_effort = last_part,
_ => {} // Keep default "medium" if not a valid value
}
}
payload
.as_object_mut()
.unwrap()
.insert("reasoning_effort".to_string(), json!(reasoning_effort));
}

Ok(payload)
}

@@ -365,6 +373,86 @@ mod tests {
}
}"#;

const EPSILON: f64 = 1e-6; // Loose tolerance; absorbs f32 -> f64 widening error in temperature checks

// Test utilities
struct TestModelConfig {
model_name: String,
tokenizer_name: String,
temperature: Option<f32>,
max_tokens: Option<i32>,
}

impl TestModelConfig {
fn new(model_name: &str, tokenizer_name: &str) -> Self {
Self {
model_name: model_name.to_string(),
tokenizer_name: tokenizer_name.to_string(),
temperature: Some(0.7),
max_tokens: Some(1024),
}
}

fn without_temperature(mut self) -> Self {
self.temperature = None;
self
}

fn to_model_config(&self) -> ModelConfig {
ModelConfig {
model_name: self.model_name.clone(),
tokenizer_name: self.tokenizer_name.clone(),
context_limit: Some(4096),
temperature: self.temperature,
max_tokens: self.max_tokens,
}
}
}

fn assert_request(
model_config: &TestModelConfig,
expected_model: &str,
expected_reasoning: Option<&str>,
expect_max_completion_tokens: bool,
) -> anyhow::Result<()> {
let request = create_request(
&model_config.to_model_config(),
"system",
&[],
&[],
&ImageFormat::OpenAi,
)?;
let obj = request.as_object().unwrap();

// Check model name
assert_eq!(obj.get("model").unwrap(), expected_model);

// Check reasoning effort
match expected_reasoning {
Some(effort) => assert_eq!(obj.get("reasoning_effort").unwrap(), effort),
None => assert!(obj.get("reasoning_effort").is_none()),
}

// Check max tokens field
if expect_max_completion_tokens {
assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
assert!(obj.get("max_tokens").is_none());
} else {
assert!(obj.get("max_completion_tokens").is_none());
assert_eq!(obj.get("max_tokens").unwrap(), 1024);
}

// Check temperature if present
if let Some(expected_temp) = model_config.temperature {
let temp = obj.get("temperature").unwrap().as_f64().unwrap();
assert!((temp - f64::from(expected_temp)).abs() < EPSILON);
} else {
assert!(obj.get("temperature").is_none());
}

Ok(())
}

#[test]
fn test_format_messages() -> anyhow::Result<()> {
let message = Message::user().with_text("Hello");
Expand Down Expand Up @@ -617,11 +705,12 @@ mod tests {
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
assert_eq!(obj.get("model").unwrap(), "o3-mini");
assert_eq!(obj.get("reasoning_effort").unwrap(), "medium");
assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
assert!(obj.get("max_tokens").is_none());

// Test custom reasoning effort for O3 model with valid suffix
// Test custom reasoning effort for O3 model
let model_config = ModelConfig {
model_name: "o3-mini-high".to_string(),
tokenizer_name: "o3-mini".to_string(),
@@ -631,11 +720,12 @@
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
assert_eq!(obj.get("model").unwrap(), "o3-mini");
assert_eq!(obj.get("reasoning_effort").unwrap(), "high");
assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
assert!(obj.get("max_tokens").is_none());

// Test invalid reasoning effort defaults to medium
// Test invalid suffix defaults to medium
let model_config = ModelConfig {
model_name: "o3-mini-invalid".to_string(),
tokenizer_name: "o3-mini".to_string(),
@@ -645,6 +735,7 @@
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
assert_eq!(obj.get("model").unwrap(), "o3-mini-invalid");
assert_eq!(obj.get("reasoning_effort").unwrap(), "medium");
assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
assert!(obj.get("max_tokens").is_none());
@@ -664,11 +755,12 @@
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
assert_eq!(obj.get("model").unwrap(), "o1");
assert_eq!(obj.get("reasoning_effort").unwrap(), "medium");
assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
assert!(obj.get("max_tokens").is_none());

// Test custom reasoning effort for O1 model with valid suffix
// Test custom reasoning effort for O1 model
let model_config = ModelConfig {
model_name: "o1-low".to_string(),
tokenizer_name: "o1".to_string(),
@@ -678,6 +770,7 @@
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
assert_eq!(obj.get("model").unwrap(), "o1");
assert_eq!(obj.get("reasoning_effort").unwrap(), "low");
assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
assert!(obj.get("max_tokens").is_none());
@@ -686,37 +779,127 @@
}

#[test]
fn test_create_request_o1_mini_not_supported() -> anyhow::Result<()> {
// Test o1-mini is not supported
let model_config = ModelConfig {
model_name: "o1-mini".to_string(),
tokenizer_name: "o1-mini".to_string(),
context_limit: Some(4096),
temperature: None,
max_tokens: Some(1024),
};
let result = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi);
fn test_o3_default_reasoning_effort() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("o3-mini", "o3-mini").without_temperature(),
"o3-mini",
Some("medium"),
true,
)
}

#[test]
fn test_o3_custom_reasoning_effort() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("o3-mini-high", "o3-mini").without_temperature(),
"o3-mini",
Some("high"),
true,
)
}

#[test]
fn test_o3_invalid_suffix_defaults_to_medium() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("o3-mini-invalid", "o3-mini").without_temperature(),
"o3-mini-invalid",
Some("medium"),
true,
)
}

#[test]
fn test_o1_default_reasoning_effort() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("o1", "o1").without_temperature(),
"o1",
Some("medium"),
true,
)
}

#[test]
fn test_o1_custom_reasoning_effort() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("o1-low", "o1").without_temperature(),
"o1",
Some("low"),
true,
)
}

#[test]
fn test_o1_mini_not_supported() -> anyhow::Result<()> {
let config = TestModelConfig::new("o1-mini", "o1-mini").without_temperature();
let result = create_request(
&config.to_model_config(),
"system",
&[],
&[],
&ImageFormat::OpenAi,
);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("o1-mini model is not currently supported"));
Ok(())
}

#[test]
fn test_create_request_non_o_family() -> anyhow::Result<()> {
// Test non-O1/O3 model has no reasoning effort and uses max_tokens
let model_config = ModelConfig {
model_name: "gpt-4".to_string(),
tokenizer_name: "gpt-4".to_string(),
context_limit: Some(4096),
temperature: Some(0.7),
max_tokens: Some(1024),
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
assert!(obj.get("reasoning_effort").is_none());
assert!(obj.get("max_completion_tokens").is_none());
assert_eq!(obj.get("max_tokens").unwrap(), 1024);
fn test_gpt4_standard_config() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("gpt-4", "gpt-4"),
"gpt-4",
None,
false,
)
}

Ok(())
#[test]
fn test_gpt4_with_version_suffix() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("gpt-4-0314", "gpt-4"),
"gpt-4-0314",
None,
false,
)
}

#[test]
fn test_gpt35_turbo_config() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("gpt-3.5-turbo", "gpt-3.5-turbo"),
"gpt-3.5-turbo",
None,
false,
)
}

#[test]
fn test_non_o_family_with_high_suffix() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("gpt-4-high-performance", "gpt-4"),
"gpt-4-high-performance",
None,
false,
)
}

#[test]
fn test_non_o_family_with_low_suffix() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("gpt-4-low-latency", "gpt-4"),
"gpt-4-low-latency",
None,
false,
)
}

#[test]
fn test_non_o_family_with_medium_suffix() -> anyhow::Result<()> {
assert_request(
&TestModelConfig::new("gpt-4-medium", "gpt-4"),
"gpt-4-medium",
None,
false,
)
}
}
