diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml
index 053f55ab3..a706d6a8e 100644
--- a/crates/goose/Cargo.toml
+++ b/crates/goose/Cargo.toml
@@ -61,6 +61,11 @@ once_cell = "1.20.2"
 dirs = "6.0.0"
 rand = "0.8.5"
 
+# For Bedrock provider
+aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
+aws-smithy-types = "1.2.12"
+aws-sdk-bedrockruntime = "1.72.0"
+
 [dev-dependencies]
 criterion = "0.5"
 tempfile = "3.15.0"
diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs
new file mode 100644
index 000000000..ad40a321c
--- /dev/null
+++ b/crates/goose/src/providers/bedrock.rs
@@ -0,0 +1,185 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use aws_sdk_bedrockruntime::operation::converse::ConverseError;
+use aws_sdk_bedrockruntime::{types as bedrock, Client};
+use mcp_core::Tool;
+
+use super::base::{Provider, ProviderMetadata, ProviderUsage};
+use super::errors::ProviderError;
+use crate::message::Message;
+use crate::model::ModelConfig;
+use crate::providers::utils::emit_debug_trace;
+
+// Import the migrated helper functions from providers/formats/bedrock.rs
+use super::formats::bedrock::{
+    from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config,
+};
+
+/// Documentation link reported in the provider metadata.
+pub const BEDROCK_DOC_LINK: &str =
+    "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";
+
+/// Default model id passed to the provider metadata.
+pub const BEDROCK_DEFAULT_MODEL: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
+/// Model ids listed as known in the provider metadata.
+pub const BEDROCK_KNOWN_MODELS: &[&str] = &[
+    "anthropic.claude-3-5-sonnet-20240620-v1:0",
+    "anthropic.claude-3-5-sonnet-20241022-v2:0",
+];
+
+/// Provider that sends completions to Amazon Bedrock through the Converse API.
+#[derive(Debug, serde::Serialize)]
+pub struct BedrockProvider {
+    #[serde(skip)]
+    client: Client,
+    model: ModelConfig,
+}
+
+impl BedrockProvider {
+    /// Creates a provider for `model`, loading the AWS SDK configuration with
+    /// `aws_config::load_from_env`.
+    // NOTE(review): `block_on` blocks the calling thread until the configuration is
+    // loaded; confirm this is acceptable when called from inside an async runtime.
+    pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let sdk_config = futures::executor::block_on(aws_config::load_from_env());
+        let client = Client::new(&sdk_config);
+
+        Ok(Self { client, model })
+    }
+}
+
+impl Default for BedrockProvider {
+    /// Builds a provider for the default model reported by `BedrockProvider::metadata`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `BedrockProvider::from_env` returns an error.
+    fn default() -> Self {
+        let model = ModelConfig::new(BedrockProvider::metadata().default_model);
+        BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider")
+    }
+}
+
+#[async_trait]
+impl Provider for BedrockProvider {
+    /// Returns the static metadata describing the Bedrock provider.
+    fn metadata() -> ProviderMetadata {
+        ProviderMetadata::new(
+            "bedrock",
+            "Amazon Bedrock",
+            "Run models through Amazon Bedrock. You may have to set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION as env vars before configuring.",
+            BEDROCK_DEFAULT_MODEL,
+            BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(),
+            BEDROCK_DOC_LINK,
+            vec![],
+        )
+    }
+
+    /// Returns a clone of the model configuration held by this provider.
+    fn get_model_config(&self) -> ModelConfig {
+        self.model.clone()
+    }
+
+    /// Sends `system`, `messages` and (when non-empty) `tools` to the Bedrock Converse
+    /// API and returns the converted reply together with its token usage.
+    ///
+    /// # Errors
+    ///
+    /// Service errors are mapped onto `ProviderError` variants; a response without a
+    /// message output yields `ProviderError::RequestFailed`.
+    #[tracing::instrument(
+        skip(self, system, messages, tools),
+        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
+    )]
+    async fn complete(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<(Message, ProviderUsage), ProviderError> {
+        let model_name = &self.model.model_name;
+
+        let mut request = self
+            .client
+            .converse()
+            .system(bedrock::SystemContentBlock::Text(system.to_string()))
+            .model_id(model_name.to_string())
+            .set_messages(Some(
+                messages
+                    .iter()
+                    .map(to_bedrock_message)
+                    .collect::<Result<_>>()?,
+            ));
+
+        if !tools.is_empty() {
+            request = request.tool_config(to_bedrock_tool_config(tools)?);
+        }
+
+        let response = request.send().await;
+
+        let response = match response {
+            Ok(response) => response,
+            Err(err) => {
+                return Err(match err.into_service_error() {
+                    ConverseError::AccessDeniedException(err) => {
+                        ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
+                    }
+                    ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
+                        format!("Failed to call Bedrock: {:?}", err),
+                    ),
+                    // A validation error carrying this message is treated as context overflow.
+                    ConverseError::ValidationException(err)
+                        if err
+                            .message()
+                            .unwrap_or_default()
+                            .contains("Input is too long for requested model.") =>
+                    {
+                        ProviderError::ContextLengthExceeded(format!(
+                            "Failed to call Bedrock: {:?}",
+                            err
+                        ))
+                    }
+                    ConverseError::ModelErrorException(err) => {
+                        ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
+                    }
+                    err => {
+                        ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
+                    }
+                });
+            }
+        };
+
+        let message = match response.output {
+            Some(bedrock::ConverseOutput::Message(message)) => message,
+            _ => {
+                return Err(ProviderError::RequestFailed(
+                    "No output from Bedrock".to_string(),
+                ))
+            }
+        };
+
+        let usage = response
+            .usage
+            .as_ref()
+            .map(from_bedrock_usage)
+            .unwrap_or_default();
+
+        let message = from_bedrock_message(&message)?;
+
+        // Add debug trace with input context
+        let debug_payload = serde_json::json!({
+            "system": system,
+            "messages": messages,
+            "tools": tools
+        });
+        emit_debug_trace(
+            &self.model,
+            &debug_payload,
+            &serde_json::to_value(&message).unwrap_or_default(),
+            &usage,
+        );
+
+        let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
+        Ok((message, provider_usage))
+    }
+}
diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs
index ed169aa7e..d17fb8893 100644
--- a/crates/goose/src/providers/factory.rs
+++ b/crates/goose/src/providers/factory.rs
@@ -2,6 +2,7 @@ use super::{
     anthropic::AnthropicProvider,
     azure::AzureProvider,
     base::{Provider, ProviderMetadata},
+    bedrock::BedrockProvider,
     databricks::DatabricksProvider,
     google::GoogleProvider,
     groq::GroqProvider,
@@ -16,6 +17,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
     vec![
         AnthropicProvider::metadata(),
         AzureProvider::metadata(),
+        BedrockProvider::metadata(),
         DatabricksProvider::metadata(),
         GoogleProvider::metadata(),
         GroqProvider::metadata(),
@@ -30,6 +32,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
         "openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
         "anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
         "azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
+        "bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
         "databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
         "groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
         "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs
new file mode 100644
index 000000000..812fda263
--- /dev/null
+++ b/crates/goose/src/providers/formats/bedrock.rs
@@ -0,0 +1,302 @@
+use std::collections::HashMap;
+use std::path::Path;
+
+use anyhow::{anyhow, bail, Result};
+use aws_sdk_bedrockruntime::types as bedrock;
+use aws_smithy_types::{Document, Number};
+use chrono::Utc;
+use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult};
+use serde_json::Value;
+
+use super::super::base::Usage;
+use crate::message::{Message, MessageContent};
+
+/// Converts a [`Message`] into a Bedrock message, converting every content item.
+pub fn to_bedrock_message(message: &Message) -> Result<bedrock::Message> {
+    bedrock::Message::builder()
+        .role(to_bedrock_role(&message.role))
+        .set_content(Some(
+            message
+                .content
+                .iter()
+                .map(to_bedrock_message_content)
+                .collect::<Result<_>>()?,
+        ))
+        .build()
+        .map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err))
+}
+
+/// Converts one [`MessageContent`] item into a Bedrock content block.
+///
+/// Image content is rejected with an error. A tool response whose result is an error is
+/// sent with `Error` status and no content.
+pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::ContentBlock> {
+    Ok(match content {
+        MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()),
+        MessageContent::Image(_) => {
+            bail!("Image content is not supported by Bedrock provider yet")
+        }
+        MessageContent::ToolRequest(tool_req) => {
+            let tool_use_id = tool_req.id.to_string();
+            let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
+                bedrock::ToolUseBlock::builder()
+                    .tool_use_id(tool_use_id)
+                    .name(call.name.to_string())
+                    .input(to_bedrock_json(&call.arguments))
+                    .build()
+            } else {
+                // The tool call itself is an error: only the id is set on the builder.
+                bedrock::ToolUseBlock::builder()
+                    .tool_use_id(tool_use_id)
+                    .build()
+            }?;
+            bedrock::ContentBlock::ToolUse(tool_use)
+        }
+        MessageContent::ToolResponse(tool_res) => {
+            let content = match &tool_res.tool_result {
+                Ok(content) => Some(
+                    content
+                        .iter()
+                        .map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c))
+                        .collect::<Result<_>>()?,
+                ),
+                Err(_) => None,
+            };
+            bedrock::ContentBlock::ToolResult(
+                bedrock::ToolResultBlock::builder()
+                    .tool_use_id(tool_res.id.to_string())
+                    .status(if content.is_some() {
+                        bedrock::ToolResultStatus::Success
+                    } else {
+                        bedrock::ToolResultStatus::Error
+                    })
+                    .set_content(content)
+                    .build()?,
+            )
+        }
+    })
+}
+
+/// Converts tool output [`Content`] into a Bedrock tool-result content block.
+///
+/// `tool_use_id` is used to prefix the names of resource documents.
+pub fn to_bedrock_tool_result_content_block(
+    tool_use_id: &str,
+    content: &Content,
+) -> Result<bedrock::ToolResultContentBlock> {
+    Ok(match content {
+        Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()),
+        Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"),
+        Content::Resource(resource) => bedrock::ToolResultContentBlock::Document(
+            to_bedrock_document(tool_use_id, &resource.resource)?,
+        ),
+    })
+}
+
+/// Maps a [`Role`] onto the Bedrock conversation role.
+pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole {
+    match role {
+        Role::User => bedrock::ConversationRole::User,
+        Role::Assistant => bedrock::ConversationRole::Assistant,
+    }
+}
+
+/// Builds the Bedrock tool configuration from `tools`.
+pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result<bedrock::ToolConfiguration> {
+    Ok(bedrock::ToolConfiguration::builder()
+        .set_tools(Some(
+            tools.iter().map(to_bedrock_tool).collect::<Result<_>>()?,
+        ))
+        .build()?)
+}
+
+/// Converts a [`Tool`] into a Bedrock tool specification with a JSON input schema.
+pub fn to_bedrock_tool(tool: &Tool) -> Result<bedrock::Tool> {
+    Ok(bedrock::Tool::ToolSpec(
+        bedrock::ToolSpecification::builder()
+            .name(tool.name.to_string())
+            .description(tool.description.to_string())
+            .input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json(
+                &tool.input_schema,
+            )))
+            .build()?,
+    ))
+}
+
+/// Recursively converts a JSON [`Value`] into a Smithy [`Document`].
+///
+/// Numbers are tried as `u64`, then `i64`, then `f64`.
+pub fn to_bedrock_json(value: &Value) -> Document {
+    match value {
+        Value::Null => Document::Null,
+        Value::Bool(bool) => Document::Bool(*bool),
+        Value::Number(num) => {
+            if let Some(n) = num.as_u64() {
+                Document::Number(Number::PosInt(n))
+            } else if let Some(n) = num.as_i64() {
+                Document::Number(Number::NegInt(n))
+            } else if let Some(n) = num.as_f64() {
+                Document::Number(Number::Float(n))
+            } else {
+                unreachable!()
+            }
+        }
+        Value::String(str) => Document::String(str.to_string()),
+        Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()),
+        Value::Object(obj) => Document::Object(HashMap::from_iter(
+            obj.into_iter()
+                .map(|(key, val)| (key.to_string(), to_bedrock_json(val))),
+        )),
+    }
+}
+
+/// Wraps a text resource as a Bedrock document named `{tool_use_id}-{name}`.
+///
+/// Blob resources are rejected with an error.
+fn to_bedrock_document(
+    tool_use_id: &str,
+    content: &ResourceContents,
+) -> Result<bedrock::DocumentBlock> {
+    let (uri, text) = match content {
+        ResourceContents::TextResourceContents { uri, text, .. } => (uri, text),
+        ResourceContents::BlobResourceContents { .. } => {
+            bail!("Blob resource content is not supported by Bedrock provider yet")
+        }
+    };
+
+    let filename = Path::new(uri)
+        .file_name()
+        .and_then(|n| n.to_str())
+        .unwrap_or(uri);
+
+    // The name is the part before the first `.`; any suffix other than the four listed
+    // extensions (including multi-dot suffixes such as `tar.md`) is sent as plain text.
+    let (name, format) = match filename.split_once('.') {
+        Some((name, "txt")) => (name, bedrock::DocumentFormat::Txt),
+        Some((name, "csv")) => (name, bedrock::DocumentFormat::Csv),
+        Some((name, "md")) => (name, bedrock::DocumentFormat::Md),
+        Some((name, "html")) => (name, bedrock::DocumentFormat::Html),
+        Some((name, _)) => (name, bedrock::DocumentFormat::Txt),
+        _ => (filename, bedrock::DocumentFormat::Txt),
+    };
+
+    // Since we can't use the full path (due to character limit and also Bedrock does not accept `/` etc.),
+    // and Bedrock wants document names to be unique, we're adding `tool_use_id` as a prefix to make
+    // document names unique.
+    let name = format!("{tool_use_id}-{name}");
+
+    bedrock::DocumentBlock::builder()
+        .format(format)
+        .name(name)
+        .source(bedrock::DocumentSource::Bytes(text.as_bytes().into()))
+        .build()
+        .map_err(|err| anyhow!("Failed to construct Bedrock document: {}", err))
+}
+
+/// Converts a Bedrock message into a [`Message`] stamped with the current UTC time.
+pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
+    let role = from_bedrock_role(message.role())?;
+    let content = message
+        .content()
+        .iter()
+        .map(from_bedrock_content_block)
+        .collect::<Result<Vec<_>>>()?;
+    let created = Utc::now().timestamp();
+
+    Ok(Message {
+        role,
+        content,
+        created,
+    })
+}
+
+/// Converts a Bedrock content block into [`MessageContent`].
+///
+/// Only text, tool-use and tool-result blocks are supported; anything else is an error.
+pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {
+    Ok(match block {
+        bedrock::ContentBlock::Text(text) => MessageContent::text(text),
+        bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request(
+            tool_use.tool_use_id.to_string(),
+            Ok(ToolCall::new(
+                tool_use.name.to_string(),
+                from_bedrock_json(&tool_use.input)?,
+            )),
+        ),
+        bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response(
+            tool_res.tool_use_id.to_string(),
+            if tool_res.content.is_empty() {
+                Err(ToolError::ExecutionError(
+                    "Empty content for tool use from Bedrock".to_string(),
+                ))
+            } else {
+                tool_res
+                    .content
+                    .iter()
+                    .map(from_bedrock_tool_result_content_block)
+                    .collect::<ToolResult<Vec<_>>>()
+            },
+        ),
+        _ => bail!("Unsupported content block type from Bedrock"),
+    })
+}
+
+/// Converts a Bedrock tool-result block into [`Content`]; only text is supported.
+pub fn from_bedrock_tool_result_content_block(
+    content: &bedrock::ToolResultContentBlock,
+) -> ToolResult<Content> {
+    Ok(match content {
+        bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()),
+        _ => {
+            return Err(ToolError::ExecutionError(
+                "Unsupported tool result from Bedrock".to_string(),
+            ))
+        }
+    })
+}
+
+/// Maps a Bedrock conversation role onto a [`Role`], failing on unknown roles.
+pub fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result<Role> {
+    Ok(match role {
+        bedrock::ConversationRole::User => Role::User,
+        bedrock::ConversationRole::Assistant => Role::Assistant,
+        _ => bail!("Unknown role from Bedrock"),
+    })
+}
+
+/// Copies the Bedrock token counts into a [`Usage`].
+pub fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage {
+    Usage {
+        input_tokens: Some(usage.input_tokens),
+        output_tokens: Some(usage.output_tokens),
+        total_tokens: Some(usage.total_tokens),
+    }
+}
+
+/// Recursively converts a Smithy [`Document`] back into a JSON [`Value`].
+///
+/// # Errors
+///
+/// Fails if `serde_json::Number::from_f64` rejects a float value.
+pub fn from_bedrock_json(document: &Document) -> Result<Value> {
+    Ok(match document {
+        Document::Null => Value::Null,
+        Document::Bool(bool) => Value::Bool(*bool),
+        Document::Number(num) => match num {
+            Number::PosInt(i) => Value::Number((*i).into()),
+            Number::NegInt(i) => Value::Number((*i).into()),
+            Number::Float(f) => Value::Number(
+                serde_json::Number::from_f64(*f).ok_or_else(|| anyhow!("Expected a valid float"))?,
+            ),
+        },
+        Document::String(str) => Value::String(str.clone()),
+        Document::Array(arr) => {
+            Value::Array(arr.iter().map(from_bedrock_json).collect::<Result<_>>()?)
+        }
+        Document::Object(obj) => Value::Object(
+            obj.iter()
+                .map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?)))
+                .collect::<Result<_>>()?,
+        ),
+    })
+}
diff --git a/crates/goose/src/providers/formats/mod.rs b/crates/goose/src/providers/formats/mod.rs
index 713468285..780f38488 100644
--- a/crates/goose/src/providers/formats/mod.rs
+++ b/crates/goose/src/providers/formats/mod.rs
@@ -1,3 +1,4 @@
 pub mod anthropic;
+pub mod bedrock;
 pub mod google;
 pub mod openai;
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index de6225767..634224fd7 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -1,6 +1,7 @@
 pub mod anthropic;
 pub mod azure;
 pub mod base;
+pub mod bedrock;
 pub mod databricks;
 pub mod errors;
 mod factory;
diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs
index 6a5f4b9da..332f3ee76 100644
--- a/crates/goose/tests/providers.rs
+++ b/crates/goose/tests/providers.rs
@@ -3,7 +3,9 @@ use dotenv::dotenv;
 use goose::message::{Message, MessageContent};
 use goose::providers::base::Provider;
 use goose::providers::errors::ProviderError;
-use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
+use goose::providers::{
+    anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter,
+};
 use mcp_core::content::Content;
 use mcp_core::tool::Tool;
 use std::collections::HashMap;
@@ -374,6 +376,34 @@ async fn test_azure_provider() -> Result<()> {
     .await
 }
 
+#[tokio::test]
+async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
+    test_provider(
+        "Bedrock",
+        &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
+        None,
+        bedrock::BedrockProvider::default,
+    )
+    .await
+}
+
+#[tokio::test]
+async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
+    let env_mods = HashMap::from_iter([
+        // Ensure to unset long-term credentials to use AWS Profile provider
+        ("AWS_ACCESS_KEY_ID", None),
+        ("AWS_SECRET_ACCESS_KEY", None),
+    ]);
+
+    test_provider(
+        "Bedrock AWS Profile Credentials",
+        &["AWS_PROFILE"],
+        Some(env_mods),
+        bedrock::BedrockProvider::default,
+    )
+    .await
+}
+
 #[tokio::test]
 async fn test_databricks_provider() -> Result<()> {
     test_provider(
diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs
index d3702d5f8..f8375086f 100644
--- a/crates/goose/tests/truncate_agent.rs
+++ b/crates/goose/tests/truncate_agent.rs
@@ -8,7 +8,7 @@ use goose::model::ModelConfig;
 use goose::providers::base::Provider;
 use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
 use goose::providers::{
-    azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
+    azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
     openrouter::OpenRouterProvider,
 };
 use goose::providers::{google::GoogleProvider, groq::GroqProvider};
@@ -18,6 +18,7 @@ enum ProviderType {
     Azure,
     OpenAi,
     Anthropic,
+    Bedrock,
     Databricks,
     Google,
     Groq,
@@ -35,6 +36,7 @@ impl ProviderType {
             ],
             ProviderType::OpenAi => &["OPENAI_API_KEY"],
             ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
+            ProviderType::Bedrock => &["AWS_PROFILE", "AWS_REGION"],
             ProviderType::Databricks => &["DATABRICKS_HOST"],
             ProviderType::Google => &["GOOGLE_API_KEY"],
             ProviderType::Groq => &["GROQ_API_KEY"],
@@ -66,6 +68,7 @@ impl ProviderType {
             ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
             ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
             ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
+            ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
             ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
             ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
             ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
@@ -200,6 +203,16 @@ mod tests {
         .await
     }
 
+    #[tokio::test]
+    async fn test_truncate_agent_with_bedrock() -> Result<()> {
+        run_test_with_config(TestConfig {
+            provider_type: ProviderType::Bedrock,
+            model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
+            context_window: 200_000,
+        })
+        .await
+    }
+
     #[tokio::test]
     async fn test_truncate_agent_with_databricks() -> Result<()> {
         run_test_with_config(TestConfig {