feat: tool calling emulation (experiment/WIP) #1016

Draft · wants to merge 2 commits into base: main
6 changes: 3 additions & 3 deletions crates/goose/src/providers/formats/openai.rs
@@ -49,7 +49,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}
Err(e) => {
output.push(json!({
"role": "tool",
"role": "user",
"content": format!("Error: {}", e),
"tool_call_id": request.id
}));
@@ -104,7 +104,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<

// First add the tool response with all content
output.push(json!({
"role": "tool",
"role": "user",
"content": tool_response_content,
"tool_call_id": response.id
}));
@@ -114,7 +114,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
Err(e) => {
// A tool result error is shown as output so the model can interpret the error message
output.push(json!({
"role": "tool",
"role": "user",
"content": format!("The tool call returned the following error:\n{}", e),
"tool_call_id": response.id
}));
217 changes: 212 additions & 5 deletions crates/goose/src/providers/openrouter.rs
@@ -7,12 +7,141 @@ use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat};
- use crate::message::Message;
+ use crate::message::{Message, MessageContent, ToolRequest};
use crate::model::ModelConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use mcp_core::{content::TextContent, tool::ToolCall, role::Role};
use mcp_core::tool::Tool;
use url::Url;

// Helper function to create a message with text content
fn create_text_message(text: String) -> Message {
let mut msg = Message::assistant();
msg.content = vec![MessageContent::Text(TextContent {
text,
annotations: None,
})];
msg
}

// Helper function to create a message with a general tool request
fn create_tool_message(tool_name: String, args: Value) -> Message {
let mut msg = Message::assistant();
msg.content = vec![MessageContent::ToolRequest(ToolRequest {
id: "1".to_string(), // Possibly refine if multiple calls are needed
tool_call: Ok(ToolCall {
name: tool_name,
arguments: args,
}),
})];
msg
}

/// Attempts to parse multiple tool usages of the form:
/// <tool_name>
/// <paramA>valueA</paramA>
/// <paramB>valueB</paramB>
/// ...
/// </tool_name>
/// <another_tool>
/// ...
/// </another_tool>
///
/// Returns a Vec<Message>, each containing a tool call.
fn parse_tool_usages(content: &str) -> Vec<Message> {
let mut messages = Vec::new();
let mut search_start = 0;

// First normalize newlines to spaces to handle multi-line format
let content = content.replace('\n', " ");

while let Some(start_idx) = content[search_start..].find('<') {
// Adjust to absolute index
let start_idx = start_idx + search_start;
let after_lt = &content[start_idx + 1..];
// Find '>' to extract the tool name
let Some(end_tool_name_idx) = after_lt.find('>') else {
break;
};
let tool_name = after_lt[..end_tool_name_idx].trim();
if tool_name.is_empty() {
break;
}

println!("Found tool: {}", tool_name); // Debug trace

let closing_tag = format!("</{}>", tool_name);
let after_tool_start = &after_lt[end_tool_name_idx + 1..];
let Some(closing_idx) = after_tool_start.find(&closing_tag) else {
break;
};

let inner_content = &after_tool_start[..closing_idx];
let mut args = json!({});
let mut param_search_start = 0;

// Parse <paramName>value</paramName>
while let Some(param_open_idx) = inner_content[param_search_start..].find('<') {
let param_open_idx = param_open_idx + param_search_start;
let after_param_lt = &inner_content[param_open_idx + 1..];
if let Some(param_close_idx) = after_param_lt.find('>') {
let param_name = after_param_lt[..param_close_idx].trim();
if param_name.is_empty() {
break;
}
let param_closing_tag = format!("</{}>", param_name);
let after_param_start = &after_param_lt[param_close_idx + 1..];
if let Some(param_closing_idx) = after_param_start.find(&param_closing_tag) {
let param_value = &after_param_start[..param_closing_idx].trim();
println!(" Param: {} = {}", param_name, param_value); // Debug trace
args[param_name] = json!(param_value);

param_search_start = param_open_idx
+ 1
+ param_close_idx
+ 1
+ param_closing_idx
+ param_closing_tag.len();
} else {
break;
}
} else {
break;
}
}

// Build the tool message
messages.push(create_tool_message(tool_name.to_string(), args));

// Advance to beyond the closing tag
search_start = start_idx + 1 + end_tool_name_idx + 1 + closing_idx + closing_tag.len();
}

// Debug trace of parsed messages
println!("\n=== Parsed Tool Messages ===");
for (i, msg) in messages.iter().enumerate() {
println!("\nMessage {}: ", i + 1);
match &msg.content[0] {
MessageContent::ToolRequest(tool_req) => {
if let Ok(tool_call) = &tool_req.tool_call {
println!(" Tool: {}", tool_call.name);
println!(" Args: {}",
serde_json::to_string_pretty(&tool_call.arguments)
.unwrap_or_else(|_| "Failed to format args".to_string())
);
}
},
MessageContent::Text(text) => {
println!(" Text: {}", text.text);
},
_ => println!(" Other content type"),
}
}
println!("========================\n");

messages
}
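// Illustrative usage sketch (not part of this diff), assuming the emulated XML format
// described in the doc comment above:
//
//     let calls = parse_tool_usages(
//         "<developer__shell>\n<command>ls -l</command>\n</developer__shell>",
//     );
//     assert_eq!(calls.len(), 1);
//     // The single assistant message carries a ToolRequest named "developer__shell"
//     // with arguments {"command": "ls -l"}.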

pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";
pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";

@@ -140,6 +269,59 @@ fn create_request_based_on_model(
messages: &[Message],
tools: &[Tool],
) -> anyhow::Result<Value, Error> {
// For deepseek models, we want to include tools in the system prompt instead
if model_config.model_name.contains("deepseek-r1") {
let tool_instructions = if !tools.is_empty() {
let tool_descriptions: Vec<String> = tools.iter()
.map(|tool| format!("- {}: {}", tool.name, tool.description))
.collect();

// println!("\nTools being provided:\n{}", tool_descriptions.join("\n"));

format!(
"\n\nAvailable tools:\n{}\n\n# Reminder: Instructions for Tool Use\n\nTool uses are formatted using XML-style tags. The tool name is enclosed in opening and closing tags. Here's the structure:\n\n<tool_name>\n<parameter1_name>value1</parameter1_name>\n<parameter2_name>value2</parameter2_name>\n...\n</tool_name>\n\nFor example, to use the shell tool:\n\n<developer__shell>\n<command>ls -l</command>\n</developer__shell>\n\nAlways adhere to this format for all tool uses to ensure proper parsing and execution.\n",
tool_descriptions.join("\n")
)
} else {
String::new()
};

let enhanced_system = format!("{}{}", system, tool_instructions);
println!("\nEnhanced system prompt:\n{}", enhanced_system);

// Find the last user message and enhance it
let mut modified_messages = messages.to_vec();
if let Some(last_user_msg_idx) = modified_messages.iter().rposition(|msg| {
// Only consider user messages that don't have a tool_call_id
msg.role == Role::User && !msg.content.iter().any(|content| {
matches!(content, MessageContent::ToolResponse(_))
})
}) {
let last_user_msg = &modified_messages[last_user_msg_idx];
// Get the text content from the last user message
let user_text = last_user_msg.content.iter().find_map(|content| {
if let MessageContent::Text(text) = content {
Some(text.text.clone())
} else {
None
}
}).unwrap_or_default();

// Create new message with enhanced system prompt prepended
let enhanced_msg = Message::user().with_text(format!("{}\n{}", enhanced_system, user_text));
modified_messages[last_user_msg_idx] = enhanced_msg;
}

let payload = create_request(
model_config,
"", // Empty system prompt since we included it in the user message
&modified_messages,
&[], // Pass empty tools array since we're handling them in the system prompt
&super::utils::ImageFormat::OpenAi,
)?;
return Ok(payload);
}

let mut payload = create_request(
model_config,
system,
@@ -198,13 +380,38 @@ impl Provider for OpenRouterProvider {
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
// Create the base payload
- let payload = create_request_based_on_model(&self.model, system, messages, tools)?;
+ let mut payload = create_request_based_on_model(&self.model, system, messages, tools)?;
// payload["provider"] = json!({"order": ["Avian"], "allow_fallbacks": false});
println!("Request Payload: {}\n", serde_json::to_string_pretty(&payload).unwrap());
// Make request
let response = self.post(payload.clone()).await?;


// Parse response - special handling for deepseek models
let message = if self.model.model_name.contains("deepseek-r1") {
let content = response["choices"][0]["message"]["content"]
.as_str()
.unwrap_or_default()
.to_string();

println!("Response payload:\n{}", serde_json::to_string_pretty(&response).unwrap());
println!("\nExtracted content:\n{}", content);

// Attempt to parse multiple tool usages from the content
let calls = parse_tool_usages(&content);

if calls.is_empty() {
// No tool calls found, treat entire content as text
create_text_message(content)
} else {
// For demonstration, return the FIRST tool call.
// If you want to handle multiple calls, see the parse_tool_usages doc.
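// Hypothetical alternative (a sketch, not implemented in this change): the parsed
// requests could instead be merged into one assistant message so every emulated
// call reaches the agent loop, e.g.
//
//     let mut merged = Message::assistant();
//     merged.content = calls.iter().flat_map(|m| m.content.clone()).collect();
//
// Note that parse_tool_usages currently assigns id "1" to every request, so the
// ids would need to be made unique first.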
calls[0].clone()
}
} else {
response_to_message(response.clone())?
};

- // Parse response
- let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {