Skip to content

Commit

Permalink
feat(webserver): add gitlab SSO support (TabbyML#2213)
Browse files Browse the repository at this point in the history
* feat(webserver): add gitlab SSO support

* Fix create_authorization_url

* Error on email failure

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
boxbeam and autofix-ci[bot] authored May 21, 2024
1 parent 59b70e6 commit 27557b5
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 0 deletions.
1 change: 1 addition & 0 deletions ee/tabby-schema/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ enum LicenseType {
enum OAuthProvider {
GITHUB
GOOGLE
GITLAB
}

enum RepositoryKind {
Expand Down
2 changes: 2 additions & 0 deletions ee/tabby-schema/src/dao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,15 @@ impl DbEnum for OAuthProvider {
match self {
OAuthProvider::Google => "google",
OAuthProvider::Github => "github",
OAuthProvider::Gitlab => "gitlab",
}
}

fn from_enum_str(s: &str) -> anyhow::Result<Self> {
match s {
"github" => Ok(OAuthProvider::Github),
"google" => Ok(OAuthProvider::Google),
"gitlab" => Ok(OAuthProvider::Gitlab),
_ => bail!("Invalid OAuth credential type"),
}
}
Expand Down
1 change: 1 addition & 0 deletions ee/tabby-schema/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ impl relay::NodeType for Invitation {
pub enum OAuthProvider {
Github,
Google,
Gitlab,
}

#[derive(GraphQLObject)]
Expand Down
164 changes: 164 additions & 0 deletions ee/tabby-webserver/src/oauth/gitlab.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use tabby_schema::auth::{AuthenticationService, OAuthCredential, OAuthProvider};

use super::OAuthClient;
use crate::bail;

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct GitlabOAuthResponse {
#[serde(default)]
access_token: String,
#[serde(default)]
scope: String,
#[serde(default)]
token_type: String,

#[serde(default)]
expires_in: i32,
#[serde(default)]
created_at: u64,
#[serde(default)]
error: Option<String>,
#[serde(default)]
error_description: Option<String>,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct GitlabUserEmail {
#[serde(default)]
email: String,
error: Option<String>,
}

pub struct GitlabClient {
client: reqwest::Client,
auth: Arc<dyn AuthenticationService>,
}

impl GitlabClient {
pub fn new(auth: Arc<dyn AuthenticationService>) -> Self {
Self {
client: reqwest::Client::new(),
auth,
}
}

async fn read_credential(&self) -> Result<OAuthCredential> {
match self
.auth
.read_oauth_credential(OAuthProvider::Gitlab)
.await?
{
Some(credential) => Ok(credential),
None => bail!("No Gitlab OAuth credential found"),
}
}

async fn exchange_access_token(
&self,
code: String,
credential: OAuthCredential,
redirect_uri: String,
) -> Result<GitlabOAuthResponse> {
let params: [(&str, &str); 5] = [
("client_id", &credential.client_id),
("client_secret", &credential.client_secret),
("code", &code),
("grant_type", "authorization_code"),
("redirect_uri", &redirect_uri),
];
let resp = self
.client
.post("https://gitlab.com/oauth/token")
.header(reqwest::header::ACCEPT, "application/json")
.form(&params)
.send()
.await?
.json::<GitlabOAuthResponse>()
.await?;

Ok(resp)
}
}

#[async_trait]
impl OAuthClient for GitlabClient {
async fn fetch_user_email(&self, code: String) -> Result<String> {
let credentials = self.read_credential().await?;
let redirect_uri = self.auth.oauth_callback_url(OAuthProvider::Gitlab).await?;
let token_resp = self
.exchange_access_token(code, credentials, redirect_uri)
.await?;

if let Some(err) = token_resp.error {
bail!(
"Error while exchanging access token: {err} {}",
token_resp
.error_description
.map(|s| format!("({s})"))
.unwrap_or_default()
);
}

if token_resp.access_token.is_empty() {
bail!("Empty access token from Gitlab OAuth");
}

let resp = self
.client
.get("https://gitlab.com/api/v4/user")
.header(reqwest::header::USER_AGENT, "Tabby")
.header(reqwest::header::ACCEPT, "application/vnd.gitlab+json")
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", token_resp.access_token),
)
.send()
.await?;

let email = resp.json::<GitlabUserEmail>().await?;
if let Some(error) = email.error {
bail!("{error}");
}
Ok(email.email)
}

async fn get_authorization_url(&self) -> Result<String> {
let credentials = self.read_credential().await?;
let redirect_uri = self.auth.oauth_callback_url(OAuthProvider::Gitlab).await?;
create_authorization_url(&credentials.client_id, &redirect_uri)
}
}

fn create_authorization_url(client_id: &str, redirect_uri: &str) -> Result<String> {
let mut url = reqwest::Url::parse("https://gitlab.com/oauth/authorize")?;
let params = vec![
("client_id", client_id),
("response_type", "code"),
("scope", "api"),
("redirect_uri", redirect_uri),
];
for (k, v) in params {
url.query_pairs_mut().append_pair(k, v);
}
Ok(url.to_string())
}

#[cfg(test)]
mod tests {
use super::create_authorization_url;

#[test]
fn test_create_authorization_url() {
let url =
create_authorization_url("client_id", "http://localhost:8080/oauth/callback/gitlab")
.unwrap();
assert_eq!(url, "https://gitlab.com/oauth/authorize?client_id=client_id&response_type=code&scope=api&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Foauth%2Fcallback%2Fgitlab");
}
}
4 changes: 4 additions & 0 deletions ee/tabby-webserver/src/oauth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod github;
mod gitlab;
mod google;

use std::sync::Arc;
Expand All @@ -9,6 +10,8 @@ use github::GithubClient;
use google::GoogleClient;
use tabby_schema::auth::{AuthenticationService, OAuthProvider};

use self::gitlab::GitlabClient;

#[async_trait]
pub trait OAuthClient: Send + Sync {
async fn fetch_user_email(&self, code: String) -> Result<String>;
Expand All @@ -20,6 +23,7 @@ pub fn new_oauth_client(
auth: Arc<dyn AuthenticationService>,
) -> Arc<dyn OAuthClient> {
match provider {
OAuthProvider::Gitlab => Arc::new(GitlabClient::new(auth)),
OAuthProvider::Google => Arc::new(GoogleClient::new(auth)),
OAuthProvider::Github => Arc::new(GithubClient::new(auth)),
}
Expand Down
18 changes: 18 additions & 0 deletions ee/tabby-webserver/src/routes/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub fn routes(state: Arc<dyn AuthenticationService>) -> Router {
.route("/providers", routing::get(providers_handler))
.route("/callback/github", routing::get(github_oauth_handler))
.route("/callback/google", routing::get(google_oauth_handler))
.route("/callback/gitlab", routing::get(gitlab_oauth_handler))
.with_state(state)
}

Expand Down Expand Up @@ -106,6 +107,23 @@ async fn google_oauth_handler(
)
}

#[derive(Deserialize)]
#[allow(dead_code)]
struct GitlabOAuthQueryParam {
code: String,
state: Option<String>,
}

async fn gitlab_oauth_handler(
State(state): State<OAuthState>,
Query(param): Query<GitlabOAuthQueryParam>,
) -> Redirect {
match_auth_result(
OAuthProvider::Gitlab,
state.oauth(param.code, OAuthProvider::Gitlab).await,
)
}

fn match_auth_result(
provider: OAuthProvider,
result: Result<OAuthResponse, OAuthError>,
Expand Down
1 change: 1 addition & 0 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ impl AuthenticationService for AuthenticationServiceImpl {
let url = match provider {
OAuthProvider::Github => external_url + "/oauth/callback/github",
OAuthProvider::Google => external_url + "/oauth/callback/google",
OAuthProvider::Gitlab => external_url + "/oauth/callback/gitlab",
};
Ok(url)
}
Expand Down

0 comments on commit 27557b5

Please sign in to comment.