diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs
index 2f588ca0d..ffdb1375d 100644
--- a/crates/goose/src/providers/formats/openai.rs
+++ b/crates/goose/src/providers/formats/openai.rs
@@ -256,15 +256,32 @@ pub fn create_request(
     tools: &[Tool],
     image_format: &ImageFormat,
 ) -> anyhow::Result<Value> {
-    if model_config.model_name.starts_with("o1-mini") {
+    let is_o1 = model_config.model_name.starts_with("o1");
+    let is_o3 = model_config.model_name.starts_with("o3");
+
+    // Only extract reasoning effort for O1/O3 models
+    let (model_name, reasoning_effort) = if is_o1 || is_o3 {
+        let parts: Vec<&str> = model_config.model_name.split('-').collect();
+        let last_part = parts.last().unwrap();
+
+        match *last_part {
+            "low" | "medium" | "high" => {
+                let base_name = parts[..parts.len() - 1].join("-");
+                (base_name, Some(last_part.to_string()))
+            }
+            _ => (model_config.model_name.to_string(), Some("medium".to_string())),
+        }
+    } else {
+        // For non-O family models, use the model name as is and no reasoning effort
+        (model_config.model_name.to_string(), None)
+    };
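+    // e.g. "o3-mini-high" splits into ["o3", "mini", "high"], yielding the base
+    // name "o3-mini" with effort "high"; a bare "o3-mini" keeps its full name
+    // and falls back to the default effort "medium".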
+
+    if model_name.starts_with("o1-mini") {
         return Err(anyhow!(
             "o1-mini model is not currently supported since Goose uses tool calling and o1-mini does not support it. Please use o1 or o3 models instead."
         ));
     }
 
-    let is_o1 = model_config.model_name.starts_with("o1");
-    let is_o3 = model_config.model_name.starts_with("o3");
-
     let system_message = json!({
         // NOTE: per OPENAI docs , With O1 and newer models, `developer`
         // should replace `system` role .
@@ -284,16 +301,26 @@ pub fn create_request(
     messages_array.extend(messages_spec);
 
     let mut payload = json!({
-        "model": model_config.model_name,
+        "model": model_name,
         "messages": messages_array
     });
 
+    // NOTE: add reasoning effort if present,
+    // e.g. if the user chooses `o3-mini-high` as their model name
+    // then it will set `reasoning_effort` to `high`.
+    // Defaults to medium per OpenAI docs:
+    // https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
+    if let Some(effort) = reasoning_effort {
+        payload.as_object_mut().unwrap()
+            .insert("reasoning_effort".to_string(), json!(effort));
+    }
+
+    // Add tools if present
     if !tools_spec.is_empty() {
-        payload
-            .as_object_mut()
-            .unwrap()
+        payload.as_object_mut().unwrap()
             .insert("tools".to_string(), json!(tools_spec));
     }
 
+    // o1, o3 models currently don't support temperature
     if !is_o1 && !is_o3 {
         if let Some(temp) = model_config.temperature {
@@ -316,26 +343,7 @@ pub fn create_request(
             .unwrap()
             .insert(key.to_string(), json!(tokens));
     }
-    // NOTE: add resoning effort if present
-    // e.g if the user chooses `o3-mini-high` as their model name
-    // then it will set `reasoning_effort` to `high`.
-    // Defaults to medium per openai docs
-    // https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort
-    if is_o1 || is_o3 {
-        let mut reasoning_effort = "medium";
-        // Extract the last part of model name using '-' as delimiter
-        if let Some(last_part) = model_config.model_name.split('-').last() {
-            // Check if it's a valid reasoning effort value
-            match last_part {
-                "low" | "medium" | "high" => reasoning_effort = last_part,
-                _ => {} // Keep default "medium" if not a valid value
-            }
-        }
-        payload
-            .as_object_mut()
-            .unwrap()
-            .insert("reasoning_effort".to_string(), json!(reasoning_effort));
-    }
+
     Ok(payload)
 }
 
@@ -365,6 +373,86 @@ mod tests {
         }
     }"#;
 
+    const EPSILON: f64 = 1e-6; // Tolerance for floating-point comparisons
+
+    // Test utilities
+    struct TestModelConfig {
+        model_name: String,
+        tokenizer_name: String,
+        temperature: Option<f32>,
+        max_tokens: Option<i32>,
+    }
+
+    impl TestModelConfig {
+        fn new(model_name: &str, tokenizer_name: &str) -> Self {
+            Self {
+                model_name: model_name.to_string(),
+                tokenizer_name: tokenizer_name.to_string(),
+                temperature: Some(0.7),
+                max_tokens: Some(1024),
+            }
+        }
+
+        fn without_temperature(mut self) -> Self {
+            self.temperature = None;
+            self
+        }
+
+        fn to_model_config(&self) -> ModelConfig {
+            ModelConfig {
+                model_name: self.model_name.clone(),
+                tokenizer_name: self.tokenizer_name.clone(),
+                context_limit: Some(4096),
+                temperature: self.temperature,
+                max_tokens: self.max_tokens,
+            }
+        }
+    }
+
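+    // Builds a request from the given test config and checks the resolved model
+    // name, the reasoning_effort value, which max-token field is present, and
+    // the temperature.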
+    fn assert_request(
+        model_config: &TestModelConfig,
+        expected_model: &str,
+        expected_reasoning: Option<&str>,
+        expect_max_completion_tokens: bool,
+    ) -> anyhow::Result<()> {
+        let request = create_request(
+            &model_config.to_model_config(),
+            "system",
+            &[],
+            &[],
+            &ImageFormat::OpenAi,
+        )?;
+        let obj = request.as_object().unwrap();
+
+        // Check model name
+        assert_eq!(obj.get("model").unwrap(), expected_model);
+
+        // Check reasoning effort
+        match expected_reasoning {
+            Some(effort) => assert_eq!(obj.get("reasoning_effort").unwrap(), effort),
+            None => assert!(obj.get("reasoning_effort").is_none()),
+        }
+
+        // Check max tokens field
+        if expect_max_completion_tokens {
+            assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
+            assert!(obj.get("max_tokens").is_none());
+        } else {
+            assert!(obj.get("max_completion_tokens").is_none());
+            assert_eq!(obj.get("max_tokens").unwrap(), 1024);
+        }
+
+        // Check temperature if present
+        if let Some(expected_temp) = model_config.temperature {
+            let temp = obj.get("temperature").unwrap().as_f64().unwrap();
+            assert!((temp - f64::from(expected_temp)).abs() < EPSILON);
+        } else {
+            assert!(obj.get("temperature").is_none());
+        }
+
+        Ok(())
+    }
+
     #[test]
     fn test_format_messages() -> anyhow::Result<()> {
         let message = Message::user().with_text("Hello");
@@ -617,11 +705,12 @@ mod tests {
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
+        assert_eq!(obj.get("model").unwrap(), "o3-mini");
         assert_eq!(obj.get("reasoning_effort").unwrap(), "medium");
         assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
         assert!(obj.get("max_tokens").is_none());
 
-        // Test custom reasoning effort for O3 model with valid suffix
+        // Test custom reasoning effort for O3 model
         let model_config = ModelConfig {
             model_name: "o3-mini-high".to_string(),
             tokenizer_name: "o3-mini".to_string(),
@@ -631,11 +720,12 @@ mod tests {
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
+        assert_eq!(obj.get("model").unwrap(), "o3-mini");
         assert_eq!(obj.get("reasoning_effort").unwrap(), "high");
         assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
         assert!(obj.get("max_tokens").is_none());
 
-        // Test invalid reasoning effort defaults to medium
+        // Test invalid suffix defaults to medium
         let model_config = ModelConfig {
             model_name: "o3-mini-invalid".to_string(),
             tokenizer_name: "o3-mini".to_string(),
@@ -645,6 +735,7 @@ mod tests {
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
+        assert_eq!(obj.get("model").unwrap(), "o3-mini-invalid");
         assert_eq!(obj.get("reasoning_effort").unwrap(), "medium");
         assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
         assert!(obj.get("max_tokens").is_none());
@@ -664,11 +755,12 @@ mod tests {
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
+        assert_eq!(obj.get("model").unwrap(), "o1");
         assert_eq!(obj.get("reasoning_effort").unwrap(), "medium");
         assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
         assert!(obj.get("max_tokens").is_none());
 
-        // Test custom reasoning effort for O1 model with valid suffix
+        // Test custom reasoning effort for O1 model
        let model_config = ModelConfig {
             model_name: "o1-low".to_string(),
             tokenizer_name: "o1".to_string(),
@@ -678,6 +770,7 @@ mod tests {
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
+        assert_eq!(obj.get("model").unwrap(), "o1");
         assert_eq!(obj.get("reasoning_effort").unwrap(), "low");
         assert_eq!(obj.get("max_completion_tokens").unwrap(), 1024);
         assert!(obj.get("max_tokens").is_none());
@@ -686,37 +779,127 @@ mod tests {
     }
 
     #[test]
-    fn test_create_request_o1_mini_not_supported() -> anyhow::Result<()> {
-        // Test o1-mini is not supported
-        let model_config = ModelConfig {
-            model_name: "o1-mini".to_string(),
-            tokenizer_name: "o1-mini".to_string(),
-            context_limit: Some(4096),
-            temperature: None,
-            max_tokens: Some(1024),
-        };
-        let result = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi);
+    fn test_o3_default_reasoning_effort() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("o3-mini", "o3-mini").without_temperature(),
+            "o3-mini",
+            Some("medium"),
+            true,
+        )
+    }
+
+    #[test]
+    fn test_o3_custom_reasoning_effort() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("o3-mini-high", "o3-mini").without_temperature(),
+            "o3-mini",
+            Some("high"),
+            true,
+        )
+    }
+
+    #[test]
+    fn test_o3_invalid_suffix_defaults_to_medium() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("o3-mini-invalid", "o3-mini").without_temperature(),
+            "o3-mini-invalid",
+            Some("medium"),
+            true,
+        )
+    }
+
+    #[test]
+    fn test_o1_default_reasoning_effort() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("o1", "o1").without_temperature(),
+            "o1",
+            Some("medium"),
+            true,
+        )
+    }
+
+    #[test]
+    fn test_o1_custom_reasoning_effort() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("o1-low", "o1").without_temperature(),
+            "o1",
+            Some("low"),
+            true,
+        )
+    }
+
+    #[test]
+    fn test_o1_mini_not_supported() -> anyhow::Result<()> {
+        let config = TestModelConfig::new("o1-mini", "o1-mini").without_temperature();
+        let result = create_request(
+            &config.to_model_config(),
+            "system",
+            &[],
+            &[],
+            &ImageFormat::OpenAi,
+        );
         assert!(result.is_err());
         assert!(result.unwrap_err().to_string().contains("o1-mini model is not currently supported"));
         Ok(())
     }
 
     #[test]
-    fn test_create_request_non_o_family() -> anyhow::Result<()> {
-        // Test non-O1/O3 model has no reasoning effort and uses max_tokens
-        let model_config = ModelConfig {
-            model_name: "gpt-4".to_string(),
-            tokenizer_name: "gpt-4".to_string(),
-            context_limit: Some(4096),
-            temperature: Some(0.7),
-            max_tokens: Some(1024),
-        };
-        let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
-        let obj = request.as_object().unwrap();
-        assert!(obj.get("reasoning_effort").is_none());
-        assert!(obj.get("max_completion_tokens").is_none());
-        assert_eq!(obj.get("max_tokens").unwrap(), 1024);
+    fn test_gpt4_standard_config() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("gpt-4", "gpt-4"),
+            "gpt-4",
+            None,
+            false,
+        )
+    }
 
-        Ok(())
+    #[test]
+    fn test_gpt4_with_version_suffix() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("gpt-4-0314", "gpt-4"),
+            "gpt-4-0314",
+            None,
+            false,
+        )
+    }
+
+    #[test]
+    fn test_gpt35_turbo_config() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("gpt-3.5-turbo", "gpt-3.5-turbo"),
+            "gpt-3.5-turbo",
+            None,
+            false,
+        )
+    }
+
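+    // The next three cases guard against false positives: suffix handling must
+    // only apply to o1/o3 models, so "high", "low", and "medium" fragments in
+    // other model names are left untouched.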
+    #[test]
+    fn test_non_o_family_with_high_suffix() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("gpt-4-high-performance", "gpt-4"),
+            "gpt-4-high-performance",
+            None,
+            false,
+        )
+    }
+
+    #[test]
+    fn test_non_o_family_with_low_suffix() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("gpt-4-low-latency", "gpt-4"),
+            "gpt-4-low-latency",
+            None,
+            false,
+        )
+    }
+
+    #[test]
+    fn test_non_o_family_with_medium_suffix() -> anyhow::Result<()> {
+        assert_request(
+            &TestModelConfig::new("gpt-4-medium", "gpt-4"),
+            "gpt-4-medium",
+            None,
+            false,
+        )
     }
 }