From ab23c5686c91aca8279fc6295766656414833aed Mon Sep 17 00:00:00 2001 From: damienrj Date: Fri, 31 Jan 2025 09:55:15 -0800 Subject: [PATCH] feat: adding oauth for gemini --- crates/goose/Cargo.toml | 1 + crates/goose/src/providers/google.rs | 118 +++++++++++-- crates/goose/src/providers/oauth.rs | 246 ++++++++++++++++----------- 3 files changed, 255 insertions(+), 110 deletions(-) diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 053f55ab3..cf38ee779 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -60,6 +60,7 @@ serde_yaml = "0.9.34" once_cell = "1.20.2" dirs = "6.0.0" rand = "0.8.5" +oauth2 = { version = "4.4", features = ["reqwest"] } [dev-dependencies] criterion = "0.5" diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index bf7d73bda..c5bfe4d40 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -1,3 +1,4 @@ +use super::oauth::{self, DEFAULT_REDIRECT_URL}; use super::errors::ProviderError; use crate::message::Message; use crate::model::ModelConfig; @@ -8,10 +9,15 @@ use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; use reqwest::{Client, StatusCode}; +use serde::{Deserialize, Serialize}; use serde_json::Value; use std::time::Duration; pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com"; +pub const GOOGLE_AUTH_ENDPOINT: &str = "https://accounts.google.com/o/oauth2/v2/auth"; +pub const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token"; +const DEFAULT_CLIENT_ID: &str = "goose-cli"; +const DEFAULT_SCOPES: &[&str] = &["https://www.googleapis.com/auth/generative-language.retriever"]; pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.0-flash-exp"; pub const GOOGLE_KNOWN_MODELS: &[&str] = &[ "models/gemini-1.5-pro-latest", @@ -24,12 +30,38 @@ pub const GOOGLE_KNOWN_MODELS: &[&str] = &[ pub const GOOGLE_DOC_URL: &str = "https://ai.google/get-started/our-models/"; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GoogleAuth { + ApiKey(String), + OAuth { + client_id: String, + client_secret: String, + redirect_url: String, + scopes: Vec, + }, +} + +impl GoogleAuth { + pub fn api_key(key: String) -> Self { + Self::ApiKey(key) + } + + pub fn oauth(client_id: String, client_secret: String) -> Self { + Self::OAuth { + client_id, + client_secret, + redirect_url: DEFAULT_REDIRECT_URL.to_string(), + scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(), + } + } +} + #[derive(Debug, serde::Serialize)] pub struct GoogleProvider { #[serde(skip)] client: Client, host: String, - api_key: String, + auth: GoogleAuth, model: ModelConfig, } @@ -43,7 +75,6 @@ impl Default for GoogleProvider { impl GoogleProvider { pub fn from_env(model: ModelConfig) -> Result { let config = crate::config::Config::global(); - let api_key: String = config.get_secret("GOOGLE_API_KEY")?; let host: String = config .get("GOOGLE_HOST") .unwrap_or_else(|_| GOOGLE_API_HOST.to_string()); @@ -52,29 +83,91 @@ impl GoogleProvider { .timeout(Duration::from_secs(600)) .build()?; + // First try API key authentication + if let Ok(api_key) = config.get_secret("GOOGLE_API_KEY") { + return Ok(Self { + client, + host, + auth: GoogleAuth::api_key(api_key), + model, + }); + } + + // Fall back to OAuth if both client ID and secret are configured + let client_id = config.get("GOOGLE_CLIENT_ID") + .map_err(|_| anyhow::anyhow!("GOOGLE_CLIENT_ID not set"))?; + + let client_secret = config.get_secret("GOOGLE_CLIENT_SECRET") + .map_err(|_| anyhow::anyhow!("GOOGLE_CLIENT_SECRET not set"))?; + + let redirect_url = config + .get("GOOGLE_REDIRECT_URL") + .unwrap_or_else(|_| DEFAULT_REDIRECT_URL.to_string()); + + let scopes = DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(); + Ok(Self { client, host, - api_key, + auth: GoogleAuth::OAuth { + client_id, + client_secret, + redirect_url, + scopes, + }, model, }) } + async fn ensure_auth_header(&self) -> Result { + match &self.auth { + GoogleAuth::ApiKey(key) => Ok(format!("Bearer {}", key)), + GoogleAuth::OAuth { + client_id, + client_secret, + scopes, + .. // Ignore redirect_url as we're using the default + } => { + oauth::get_oauth_token_with_endpoints_async( + GOOGLE_AUTH_ENDPOINT, + GOOGLE_TOKEN_ENDPOINT, + client_id, + client_secret, + scopes, + ) + .await + .map_err(|e| ProviderError::Authentication(format!("Failed to get OAuth token: {}", e))) + .map(|token| format!("Bearer {}", token)) + } + } + } + async fn post(&self, payload: Value) -> Result { + let auth = self.ensure_auth_header().await?; let url = format!( - "{}/v1beta/models/{}:generateContent?key={}", + "{}/v1beta/models/{}:generateContent", self.host.trim_end_matches('/'), self.model.model_name, - self.api_key ); - let response = self + // Add auth either as query param for API key or header for OAuth + let mut request = self .client .post(&url) - .header("CONTENT_TYPE", "application/json") - .json(&payload) - .send() - .await?; + .header("Content-Type", "application/json"); + + match &self.auth { + GoogleAuth::ApiKey(_) => { + // Remove "Bearer " prefix for API key and pass as query param + let api_key = auth.trim_start_matches("Bearer ").to_string(); + request = request.query(&[("key", api_key)]); + } + GoogleAuth::OAuth { .. } => { + request = request.header("Authorization", auth); + } + } + + let response = request.json(&payload).send().await?; let status = response.status(); let payload: Option = response.json().await.ok(); @@ -128,8 +221,11 @@ impl Provider for GoogleProvider { GOOGLE_KNOWN_MODELS.iter().map(|&s| s.to_string()).collect(), GOOGLE_DOC_URL, vec![ - ConfigKey::new("GOOGLE_API_KEY", true, true, None), + ConfigKey::new("GOOGLE_API_KEY", false, true, None), ConfigKey::new("GOOGLE_HOST", false, false, Some(GOOGLE_API_HOST)), + ConfigKey::new("GOOGLE_CLIENT_ID", false, false, None), + ConfigKey::new("GOOGLE_CLIENT_SECRET", false, true, None), + ConfigKey::new("GOOGLE_REDIRECT_URL", false, false, Some(DEFAULT_REDIRECT_URL)), ], ) } diff --git a/crates/goose/src/providers/oauth.rs b/crates/goose/src/providers/oauth.rs index 00bc66733..9baa98e35 100644 --- a/crates/goose/src/providers/oauth.rs +++ b/crates/goose/src/providers/oauth.rs @@ -1,18 +1,20 @@ use anyhow::Result; use axum::{extract::Query, response::Html, routing::get, Router}; use base64::Engine; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, Duration, Utc}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use serde_json::Value; use sha2::Digest; use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc}; use tokio::sync::{oneshot, Mutex as TokioMutex}; +use url::Url; lazy_static! { static ref OAUTH_MUTEX: TokioMutex<()> = TokioMutex::new(()); } -use url::Url; + +pub const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; #[derive(Debug, Clone)] struct OidcEndpoints { @@ -20,58 +22,59 @@ struct OidcEndpoints { token_endpoint: String, } -#[derive(Serialize, Deserialize)] -struct TokenData { +#[derive(Debug, Serialize, Deserialize)] +struct TokenCache { access_token: String, expires_at: Option>, } -struct TokenCache { - cache_path: PathBuf, +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: Option, } -fn get_base_path() -> PathBuf { - const BASE_PATH: &str = ".config/goose/databricks/oauth"; - let home_dir = std::env::var("HOME").expect("HOME environment variable not set"); - PathBuf::from(home_dir).join(BASE_PATH) +fn get_cache_path(client_id: &str, scopes: &[String]) -> PathBuf { + let mut hasher = sha2::Sha256::new(); + hasher.update(client_id.as_bytes()); + hasher.update(scopes.join(",").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + + let base_path = dirs::config_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join("goose/google/oauth"); + + fs::create_dir_all(&base_path).unwrap_or_default(); + base_path.join(format!("{}.json", hash)) } -impl TokenCache { - fn new(host: &str, client_id: &str, scopes: &[String]) -> Self { - let mut hasher = sha2::Sha256::new(); - hasher.update(host.as_bytes()); - hasher.update(client_id.as_bytes()); - hasher.update(scopes.join(",").as_bytes()); - let hash = format!("{:x}", hasher.finalize()); - - fs::create_dir_all(get_base_path()).unwrap(); - let cache_path = get_base_path().join(format!("{}.json", hash)); - - Self { cache_path } - } - - fn load_token(&self) -> Option { - if let Ok(contents) = fs::read_to_string(&self.cache_path) { - if let Ok(token_data) = serde_json::from_str::(&contents) { - if let Some(expires_at) = token_data.expires_at { - if expires_at > Utc::now() { - return Some(token_data); - } - } else { - return Some(token_data); +fn load_cached_token(client_id: &str, scopes: &[String]) -> Option { + let cache_path = get_cache_path(client_id, scopes); + if let Ok(contents) = fs::read_to_string(cache_path) { + if let Ok(cache) = serde_json::from_str::(&contents) { + if let Some(expires_at) = cache.expires_at { + if expires_at > Utc::now() { + return Some(cache.access_token); } } } - None } + None +} - fn save_token(&self, token_data: &TokenData) -> Result<()> { - if let Some(parent) = self.cache_path.parent() { - fs::create_dir_all(parent)?; - } - let contents = serde_json::to_string(token_data)?; - fs::write(&self.cache_path, contents)?; - Ok(()) +fn save_token_cache(client_id: &str, scopes: &[String], token: &str, expires_in: Option) { + let expires_at = expires_in.map(|secs| { + Utc::now() + Duration::seconds(secs as i64) + }); + + let cache = TokenCache { + access_token: token.to_string(), + expires_at, + }; + + let cache_path = get_cache_path(client_id, scopes); + if let Ok(contents) = serde_json::to_string(&cache) { + fs::write(cache_path, contents).unwrap_or_default(); } } @@ -112,6 +115,7 @@ async fn get_workspace_endpoints(host: &str) -> Result { struct OAuthFlow { endpoints: OidcEndpoints, client_id: String, + client_secret: String, redirect_url: String, scopes: Vec, state: String, @@ -122,12 +126,14 @@ impl OAuthFlow { fn new( endpoints: OidcEndpoints, client_id: String, + client_secret: String, redirect_url: String, scopes: Vec, ) -> Self { Self { endpoints, client_id, + client_secret, redirect_url, scopes, state: nanoid::nanoid!(16), @@ -158,52 +164,32 @@ impl OAuthFlow { ) } - async fn exchange_code_for_token(&self, code: &str) -> Result { + async fn exchange_code(&self, code: &str) -> Result { + let client = reqwest::Client::new(); let params = [ - ("grant_type", "authorization_code"), + ("client_id", self.client_id.as_str()), + ("client_secret", self.client_secret.as_str()), ("code", code), - ("redirect_uri", &self.redirect_url), - ("code_verifier", &self.verifier), - ("client_id", &self.client_id), + ("redirect_uri", self.redirect_url.as_str()), + ("grant_type", "authorization_code"), + ("code_verifier", self.verifier.as_str()), ]; - let client = reqwest::Client::new(); - let resp = client + let response = client .post(&self.endpoints.token_endpoint) - .header("Content-Type", "application/x-www-form-urlencoded") .form(¶ms) .send() .await?; - if !resp.status().is_success() { - let err_text = resp.text().await?; - return Err(anyhow::anyhow!( - "Failed to exchange code for token: {}", - err_text - )); + if !response.status().is_success() { + let error = response.text().await?; + return Err(anyhow::anyhow!("Failed to exchange code for token: {}", error)); } - let token_response: Value = resp.json().await?; - let access_token = token_response - .get("access_token") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))? - .to_string(); - - let expires_in = token_response - .get("expires_in") - .and_then(|v| v.as_u64()) - .unwrap_or(3600); - - let expires_at = Utc::now() + chrono::Duration::seconds(expires_in as i64); - - Ok(TokenData { - access_token, - expires_at: Some(expires_at), - }) + response.json().await.map_err(Into::into) } - async fn execute(&self) -> Result { + async fn execute(&self) -> Result { // Create a channel that will send the auth code from the app process let (tx, rx) = oneshot::channel(); let state = self.state.clone(); @@ -275,7 +261,25 @@ impl OAuthFlow { server_handle.abort(); // Exchange the code for a token - self.exchange_code_for_token(&code).await + self.exchange_code(&code).await + } + + fn new_with_endpoints( + endpoints: OidcEndpoints, + client_id: String, + client_secret: String, + redirect_url: String, + scopes: Vec, + ) -> Self { + Self { + endpoints, + client_id, + client_secret, + redirect_url, + scopes, + state: nanoid::nanoid!(16), + verifier: nanoid::nanoid!(64), + } } } @@ -285,31 +289,71 @@ pub(crate) async fn get_oauth_token_async( redirect_url: &str, scopes: &[String], ) -> Result { - // Acquire the global mutex to ensure only one OAuth flow runs at a time - let _guard = OAUTH_MUTEX.lock().await; - - let token_cache = TokenCache::new(host, client_id, scopes); - - // Try cache first - if let Some(token) = token_cache.load_token() { - return Ok(token.access_token); + // Try to load from cache first + if let Some(token) = load_cached_token(client_id, scopes) { + return Ok(token); } - // Get endpoints and execute flow + // Get OIDC configuration let endpoints = get_workspace_endpoints(host).await?; + + // If no valid cached token, perform OAuth flow let flow = OAuthFlow::new( endpoints, client_id.to_string(), + client_id.to_string(), redirect_url.to_string(), scopes.to_vec(), ); - // Execute the OAuth flow and get token - let token = flow.execute().await?; + let token_response = flow.execute().await?; + + // Cache the token before returning + save_token_cache( + client_id, + scopes, + &token_response.access_token, + token_response.expires_in, + ); - // Cache and return - token_cache.save_token(&token)?; - Ok(token.access_token) + Ok(token_response.access_token) +} + +pub async fn get_oauth_token_with_endpoints_async( + auth_endpoint: &str, + token_endpoint: &str, + client_id: &str, + client_secret: &str, + scopes: &[String], +) -> Result { + // Try to load from cache first + if let Some(token) = load_cached_token(client_id, scopes) { + return Ok(token); + } + + // If no valid cached token, perform OAuth flow + let flow = OAuthFlow::new_with_endpoints( + OidcEndpoints { + authorization_endpoint: auth_endpoint.to_string(), + token_endpoint: token_endpoint.to_string(), + }, + client_id.to_string(), + client_secret.to_string(), + DEFAULT_REDIRECT_URL.to_string(), + scopes.to_vec(), + ); + + let token_response = flow.execute().await?; + + // Cache the token before returning + save_token_cache( + client_id, + scopes, + &token_response.access_token, + token_response.expires_in, + ); + + Ok(token_response.access_token) } #[cfg(test)] @@ -348,21 +392,25 @@ mod tests { #[test] fn test_token_cache() -> Result<()> { - let cache = TokenCache::new( - "https://example.com", - "test-client", - &["scope1".to_string()], - ); + let cache = TokenCache { + access_token: "test-token".to_string(), + expires_at: Some(Utc::now() + Duration::seconds(3600)), + }; - let token_data = TokenData { + let token_data = TokenResponse { access_token: "test-token".to_string(), - expires_at: Some(Utc::now() + chrono::Duration::hours(1)), + expires_in: Some(3600), }; - cache.save_token(&token_data)?; + save_token_cache( + "https://example.com", + &["scope1".to_string()], + &token_data.access_token, + token_data.expires_in, + ); - let loaded_token = cache.load_token().unwrap(); - assert_eq!(loaded_token.access_token, token_data.access_token); + let loaded_token = load_cached_token("https://example.com", &["scope1".to_string()]).unwrap(); + assert_eq!(loaded_token, token_data.access_token); Ok(()) }