Skip to content

Commit

Permalink
feat: adding oauth for gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
damienrj committed Jan 31, 2025
1 parent 06a2464 commit ab23c56
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 110 deletions.
1 change: 1 addition & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ serde_yaml = "0.9.34"
once_cell = "1.20.2"
dirs = "6.0.0"
rand = "0.8.5"
oauth2 = { version = "4.4", features = ["reqwest"] }

[dev-dependencies]
criterion = "0.5"
Expand Down
118 changes: 107 additions & 11 deletions crates/goose/src/providers/google.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::oauth::{self, DEFAULT_REDIRECT_URL};
use super::errors::ProviderError;
use crate::message::Message;
use crate::model::ModelConfig;
Expand All @@ -8,10 +9,15 @@ use anyhow::Result;
use async_trait::async_trait;
use mcp_core::tool::Tool;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;

pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
pub const GOOGLE_AUTH_ENDPOINT: &str = "https://accounts.google.com/o/oauth2/v2/auth";
pub const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
const DEFAULT_CLIENT_ID: &str = "goose-cli";
const DEFAULT_SCOPES: &[&str] = &["https://www.googleapis.com/auth/generative-language.retriever"];
pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.0-flash-exp";
pub const GOOGLE_KNOWN_MODELS: &[&str] = &[
"models/gemini-1.5-pro-latest",
Expand All @@ -24,12 +30,38 @@ pub const GOOGLE_KNOWN_MODELS: &[&str] = &[

pub const GOOGLE_DOC_URL: &str = "https://ai.google/get-started/our-models/";

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GoogleAuth {
ApiKey(String),
OAuth {
client_id: String,
client_secret: String,
redirect_url: String,
scopes: Vec<String>,
},
}

impl GoogleAuth {
pub fn api_key(key: String) -> Self {
Self::ApiKey(key)
}

pub fn oauth(client_id: String, client_secret: String) -> Self {
Self::OAuth {
client_id,
client_secret,
redirect_url: DEFAULT_REDIRECT_URL.to_string(),
scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
}
}
}

#[derive(Debug, serde::Serialize)]
pub struct GoogleProvider {
#[serde(skip)]
client: Client,
host: String,
api_key: String,
auth: GoogleAuth,
model: ModelConfig,
}

Expand All @@ -43,7 +75,6 @@ impl Default for GoogleProvider {
impl GoogleProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config.get_secret("GOOGLE_API_KEY")?;
let host: String = config
.get("GOOGLE_HOST")
.unwrap_or_else(|_| GOOGLE_API_HOST.to_string());
Expand All @@ -52,29 +83,91 @@ impl GoogleProvider {
.timeout(Duration::from_secs(600))
.build()?;

// First try API key authentication
if let Ok(api_key) = config.get_secret("GOOGLE_API_KEY") {
return Ok(Self {
client,
host,
auth: GoogleAuth::api_key(api_key),
model,
});
}

// Fall back to OAuth if both client ID and secret are configured
let client_id = config.get("GOOGLE_CLIENT_ID")
.map_err(|_| anyhow::anyhow!("GOOGLE_CLIENT_ID not set"))?;

let client_secret = config.get_secret("GOOGLE_CLIENT_SECRET")
.map_err(|_| anyhow::anyhow!("GOOGLE_CLIENT_SECRET not set"))?;

let redirect_url = config
.get("GOOGLE_REDIRECT_URL")
.unwrap_or_else(|_| DEFAULT_REDIRECT_URL.to_string());

let scopes = DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect();

Ok(Self {
client,
host,
api_key,
auth: GoogleAuth::OAuth {
client_id,
client_secret,
redirect_url,
scopes,
},
model,
})
}

async fn ensure_auth_header(&self) -> Result<String, ProviderError> {
match &self.auth {
GoogleAuth::ApiKey(key) => Ok(format!("Bearer {}", key)),
GoogleAuth::OAuth {
client_id,
client_secret,
scopes,
.. // Ignore redirect_url as we're using the default
} => {
oauth::get_oauth_token_with_endpoints_async(
GOOGLE_AUTH_ENDPOINT,
GOOGLE_TOKEN_ENDPOINT,
client_id,
client_secret,
scopes,
)
.await
.map_err(|e| ProviderError::Authentication(format!("Failed to get OAuth token: {}", e)))
.map(|token| format!("Bearer {}", token))
}
}
}

async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let auth = self.ensure_auth_header().await?;
let url = format!(
"{}/v1beta/models/{}:generateContent?key={}",
"{}/v1beta/models/{}:generateContent",
self.host.trim_end_matches('/'),
self.model.model_name,
self.api_key
);

let response = self
// Add auth either as query param for API key or header for OAuth
let mut request = self
.client
.post(&url)
.header("CONTENT_TYPE", "application/json")
.json(&payload)
.send()
.await?;
.header("Content-Type", "application/json");

match &self.auth {
GoogleAuth::ApiKey(_) => {
// Remove "Bearer " prefix for API key and pass as query param
let api_key = auth.trim_start_matches("Bearer ").to_string();
request = request.query(&[("key", api_key)]);
}
GoogleAuth::OAuth { .. } => {
request = request.header("Authorization", auth);
}
}

let response = request.json(&payload).send().await?;

let status = response.status();
let payload: Option<Value> = response.json().await.ok();
Expand Down Expand Up @@ -128,8 +221,11 @@ impl Provider for GoogleProvider {
GOOGLE_KNOWN_MODELS.iter().map(|&s| s.to_string()).collect(),
GOOGLE_DOC_URL,
vec![
ConfigKey::new("GOOGLE_API_KEY", true, true, None),
ConfigKey::new("GOOGLE_API_KEY", false, true, None),
ConfigKey::new("GOOGLE_HOST", false, false, Some(GOOGLE_API_HOST)),
ConfigKey::new("GOOGLE_CLIENT_ID", false, false, None),
ConfigKey::new("GOOGLE_CLIENT_SECRET", false, true, None),
ConfigKey::new("GOOGLE_REDIRECT_URL", false, false, Some(DEFAULT_REDIRECT_URL)),
],
)
}
Expand Down
Loading

0 comments on commit ab23c56

Please sign in to comment.