diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 77405fbf6e..65f4f3ed58 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -49,7 +49,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } Err(e) => { output.push(json!({ - "role": "tool", + "role": "user", "content": format!("Error: {}", e), "tool_call_id": request.id })); @@ -104,7 +104,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< // First add the tool response with all content output.push(json!({ - "role": "tool", + "role": "user", "content": tool_response_content, "tool_call_id": response.id })); @@ -114,7 +114,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< Err(e) => { // A tool result error is shown as output so the model can interpret the error message output.push(json!({ - "role": "tool", + "role": "user", "content": format!("The tool call returned the following error:\n{}", e), "tool_call_id": response.id })); diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index ae94a5e035..d659c314eb 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -7,12 +7,141 @@ use std::time::Duration; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat}; -use crate::message::Message; +use crate::message::{Message, MessageContent, ToolRequest}; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; +use mcp_core::{content::TextContent, tool::ToolCall, role::Role}; use mcp_core::tool::Tool; use url::Url; +// Helper function to create a message with text content +fn create_text_message(text: String) -> Message { + let mut msg = Message::assistant(); + msg.content = vec![MessageContent::Text(TextContent { + text, + annotations: None, + })]; + msg +} + +// Helper function to create a message with a general tool request +fn create_tool_message(tool_name: String, args: Value) -> Message { + let mut msg = Message::assistant(); + msg.content = vec![MessageContent::ToolRequest(ToolRequest { + id: "1".to_string(), // Possibly refine if multiple calls are needed + tool_call: Ok(ToolCall { + name: tool_name, + arguments: args, + }), + })]; + msg +} + +/// Attempts to parse multiple tool usages of the form: +/// +/// valueA +/// valueB +/// ... +/// +/// +/// ... +/// +/// +/// Returns a Vec, each containing a tool call. +fn parse_tool_usages(content: &str) -> Vec { + let mut messages = Vec::new(); + let mut search_start = 0; + + // First normalize newlines to spaces to handle multi-line format + let content = content.replace('\n', " "); + + while let Some(start_idx) = content[search_start..].find('<') { + // Adjust to absolute index + let start_idx = start_idx + search_start; + let after_lt = &content[start_idx + 1..]; + // Find '>' to extract the tool name + let Some(end_tool_name_idx) = after_lt.find('>') else { + break; + }; + let tool_name = after_lt[..end_tool_name_idx].trim(); + if tool_name.is_empty() { + break; + } + + println!("Found tool: {}", tool_name); // Debug trace + + let closing_tag = format!("", tool_name); + let after_tool_start = &after_lt[end_tool_name_idx + 1..]; + let Some(closing_idx) = after_tool_start.find(&closing_tag) else { + break; + }; + + let inner_content = &after_tool_start[..closing_idx]; + let mut args = json!({}); + let mut param_search_start = 0; + + // Parse value + while let Some(param_open_idx) = inner_content[param_search_start..].find('<') { + let param_open_idx = param_open_idx + param_search_start; + let after_param_lt = &inner_content[param_open_idx + 1..]; + if let Some(param_close_idx) = after_param_lt.find('>') { + let param_name = after_param_lt[..param_close_idx].trim(); + if param_name.is_empty() { + break; + } + let param_closing_tag = format!("", param_name); + let after_param_start = &after_param_lt[param_close_idx + 1..]; + if let Some(param_closing_idx) = after_param_start.find(¶m_closing_tag) { + let param_value = &after_param_start[..param_closing_idx].trim(); + println!(" Param: {} = {}", param_name, param_value); // Debug trace + args[param_name] = json!(param_value); + + param_search_start = param_open_idx + + 1 + + param_close_idx + + 1 + + param_closing_idx + + param_closing_tag.len(); + } else { + break; + } + } else { + break; + } + } + + // Build the tool message + messages.push(create_tool_message(tool_name.to_string(), args)); + + // Advance to beyond the closing tag + search_start = start_idx + 1 + end_tool_name_idx + 1 + closing_idx + closing_tag.len(); + } + + // Debug trace of parsed messages + println!("\n=== Parsed Tool Messages ==="); + for (i, msg) in messages.iter().enumerate() { + println!("\nMessage {}: ", i + 1); + match &msg.content[0] { + MessageContent::ToolRequest(tool_req) => { + if let Ok(tool_call) = &tool_req.tool_call { + println!(" Tool: {}", tool_call.name); + println!(" Args: {}", + serde_json::to_string_pretty(&tool_call.arguments) + .unwrap_or_else(|_| "Failed to format args".to_string()) + ); + } + }, + MessageContent::Text(text) => { + println!(" Text: {}", text.text); + }, + _ => println!(" Other content type"), + } + } + println!("========================\n"); + + messages +} + pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet"; pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic"; @@ -140,6 +269,59 @@ fn create_request_based_on_model( messages: &[Message], tools: &[Tool], ) -> anyhow::Result { + // For deepseek models, we want to include tools in the system prompt instead + if model_config.model_name.contains("deepseek-r1") { + let tool_instructions = if !tools.is_empty() { + let tool_descriptions: Vec = tools.iter() + .map(|tool| format!("- {}: {}", tool.name, tool.description)) + .collect(); + + // println!("\nTools being provided:\n{}", tool_descriptions.join("\n")); + + format!( + "\n\nAvailable tools:\n{}\n\n# Reminder: Instructions for Tool Use\n\nTool uses are formatted using XML-style tags. The tool name is enclosed in opening and closing tags. Here's the structure:\n\n\nvalue1\nvalue2\n...\n\n\nFor example, to use the shell tool:\n\n\nls -l\n\n\nAlways adhere to this format for all tool uses to ensure proper parsing and execution.\n", + tool_descriptions.join("\n") + ) + } else { + String::new() + }; + + let enhanced_system = format!("{}{}", system, tool_instructions); + println!("\nEnhanced system prompt:\n{}", enhanced_system); + + // Find the last user message and enhance it + let mut modified_messages = messages.to_vec(); + if let Some(last_user_msg_idx) = modified_messages.iter().rposition(|msg| { + // Only consider user messages that don't have a tool_call_id + msg.role == Role::User && !msg.content.iter().any(|content| { + matches!(content, MessageContent::ToolResponse(_)) + }) + }) { + let last_user_msg = &modified_messages[last_user_msg_idx]; + // Get the text content from the last user message + let user_text = last_user_msg.content.iter().find_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.text.clone()) + } else { + None + } + }).unwrap_or_default(); + + // Create new message with enhanced system prompt prepended + let enhanced_msg = Message::user().with_text(format!("{}\n{}", enhanced_system, user_text)); + modified_messages[last_user_msg_idx] = enhanced_msg; + } + + let payload = create_request( + model_config, + "", // Empty system prompt since we included it in the user message + &modified_messages, + &[], // Pass empty tools array since we're handling them in the system prompt + &super::utils::ImageFormat::OpenAi, + )?; + return Ok(payload); + } + let mut payload = create_request( model_config, system, @@ -198,13 +380,38 @@ impl Provider for OpenRouterProvider { tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { // Create the base payload - let payload = create_request_based_on_model(&self.model, system, messages, tools)?; - + let mut payload = create_request_based_on_model(&self.model, system, messages, tools)?; + // payload["provider"] = json!({"order": ["Avian"], "allow_fallbacks": false}); + println!("Request Payload: {}\n", serde_json::to_string_pretty(&payload).unwrap()); // Make request let response = self.post(payload.clone()).await?; + + + // Parse response - special handling for deepseek models + let message = if self.model.model_name.contains("deepseek-r1") { + let content = response["choices"][0]["message"]["content"] + .as_str() + .unwrap_or_default() + .to_string(); + + println!("Response payload:\n{}", serde_json::to_string_pretty(&response).unwrap()); + println!("\nExtracted content:\n{}", content); + + // Attempt to parse multiple tool usability from the content + let calls = parse_tool_usages(&content); + + if calls.is_empty() { + // No tool calls found, treat entire content as text + create_text_message(content) + } else { + // For demonstration, return the FIRST tool call. + // If you want to handle multiple calls, see the parse_tool_usages doc. + calls[0].clone() + } + } else { + response_to_message(response.clone())? + }; - // Parse response - let message = response_to_message(response.clone())?; let usage = match get_usage(&response) { Ok(usage) => usage, Err(ProviderError::UsageError(e)) => {