forked from TabbyML/tabby
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(webserver): add gitlab SSO support (TabbyML#2213)
* feat(webserver): add gitlab SSO support * Fix create_authorization_url * Error on email failure * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
59b70e6
commit 27557b5
Showing
7 changed files
with
191 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ enum LicenseType { | |
enum OAuthProvider { | ||
GITHUB | ||
GITLAB | ||
} | ||
|
||
enum RepositoryKind { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
use std::sync::Arc; | ||
|
||
use anyhow::Result; | ||
use async_trait::async_trait; | ||
use serde::Deserialize; | ||
use tabby_schema::auth::{AuthenticationService, OAuthCredential, OAuthProvider}; | ||
|
||
use super::OAuthClient; | ||
use crate::bail; | ||
|
||
#[derive(Debug, Deserialize)] | ||
#[allow(dead_code)] | ||
struct GitlabOAuthResponse { | ||
#[serde(default)] | ||
access_token: String, | ||
#[serde(default)] | ||
scope: String, | ||
#[serde(default)] | ||
token_type: String, | ||
|
||
#[serde(default)] | ||
expires_in: i32, | ||
#[serde(default)] | ||
created_at: u64, | ||
#[serde(default)] | ||
error: Option<String>, | ||
#[serde(default)] | ||
error_description: Option<String>, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
#[allow(dead_code)] | ||
struct GitlabUserEmail { | ||
#[serde(default)] | ||
email: String, | ||
error: Option<String>, | ||
} | ||
|
||
pub struct GitlabClient { | ||
client: reqwest::Client, | ||
auth: Arc<dyn AuthenticationService>, | ||
} | ||
|
||
impl GitlabClient { | ||
pub fn new(auth: Arc<dyn AuthenticationService>) -> Self { | ||
Self { | ||
client: reqwest::Client::new(), | ||
auth, | ||
} | ||
} | ||
|
||
async fn read_credential(&self) -> Result<OAuthCredential> { | ||
match self | ||
.auth | ||
.read_oauth_credential(OAuthProvider::Gitlab) | ||
.await? | ||
{ | ||
Some(credential) => Ok(credential), | ||
None => bail!("No Gitlab OAuth credential found"), | ||
} | ||
} | ||
|
||
async fn exchange_access_token( | ||
&self, | ||
code: String, | ||
credential: OAuthCredential, | ||
redirect_uri: String, | ||
) -> Result<GitlabOAuthResponse> { | ||
let params: [(&str, &str); 5] = [ | ||
("client_id", &credential.client_id), | ||
("client_secret", &credential.client_secret), | ||
("code", &code), | ||
("grant_type", "authorization_code"), | ||
("redirect_uri", &redirect_uri), | ||
]; | ||
let resp = self | ||
.client | ||
.post("https://gitlab.com/oauth/token") | ||
.header(reqwest::header::ACCEPT, "application/json") | ||
.form(¶ms) | ||
.send() | ||
.await? | ||
.json::<GitlabOAuthResponse>() | ||
.await?; | ||
|
||
Ok(resp) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl OAuthClient for GitlabClient { | ||
async fn fetch_user_email(&self, code: String) -> Result<String> { | ||
let credentials = self.read_credential().await?; | ||
let redirect_uri = self.auth.oauth_callback_url(OAuthProvider::Gitlab).await?; | ||
let token_resp = self | ||
.exchange_access_token(code, credentials, redirect_uri) | ||
.await?; | ||
|
||
if let Some(err) = token_resp.error { | ||
bail!( | ||
"Error while exchanging access token: {err} {}", | ||
token_resp | ||
.error_description | ||
.map(|s| format!("({s})")) | ||
.unwrap_or_default() | ||
); | ||
} | ||
|
||
if token_resp.access_token.is_empty() { | ||
bail!("Empty access token from Gitlab OAuth"); | ||
} | ||
|
||
let resp = self | ||
.client | ||
.get("https://gitlab.com/api/v4/user") | ||
.header(reqwest::header::USER_AGENT, "Tabby") | ||
.header(reqwest::header::ACCEPT, "application/vnd.gitlab+json") | ||
.header( | ||
reqwest::header::AUTHORIZATION, | ||
format!("Bearer {}", token_resp.access_token), | ||
) | ||
.send() | ||
.await?; | ||
|
||
let email = resp.json::<GitlabUserEmail>().await?; | ||
if let Some(error) = email.error { | ||
bail!("{error}"); | ||
} | ||
Ok(email.email) | ||
} | ||
|
||
async fn get_authorization_url(&self) -> Result<String> { | ||
let credentials = self.read_credential().await?; | ||
let redirect_uri = self.auth.oauth_callback_url(OAuthProvider::Gitlab).await?; | ||
create_authorization_url(&credentials.client_id, &redirect_uri) | ||
} | ||
} | ||
|
||
fn create_authorization_url(client_id: &str, redirect_uri: &str) -> Result<String> { | ||
let mut url = reqwest::Url::parse("https://gitlab.com/oauth/authorize")?; | ||
let params = vec![ | ||
("client_id", client_id), | ||
("response_type", "code"), | ||
("scope", "api"), | ||
("redirect_uri", redirect_uri), | ||
]; | ||
for (k, v) in params { | ||
url.query_pairs_mut().append_pair(k, v); | ||
} | ||
Ok(url.to_string()) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::create_authorization_url; | ||
|
||
#[test] | ||
fn test_create_authorization_url() { | ||
let url = | ||
create_authorization_url("client_id", "http://localhost:8080/oauth/callback/gitlab") | ||
.unwrap(); | ||
assert_eq!(url, "https://gitlab.com/oauth/authorize?client_id=client_id&response_type=code&scope=api&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Foauth%2Fcallback%2Fgitlab"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters