Skip to content

Commit

Permalink
update openidconnect and oauth2
Browse files Browse the repository at this point in the history
  • Loading branch information
Kakadus committed Oct 19, 2024
1 parent f889f39 commit 6e3e4c1
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 31 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ gloo-utils = "0.2"
js-sys = "0.3"
log = "0.4"
num-traits = "0.2"
oauth2 = "4"
reqwest = "0.11"
oauth2 = "5.0.0-rc.1"
reqwest = "0.12"
serde = { version = "1", features = ["derive"] }
time = { version = "0.3", features = ["wasm-bindgen"] }
tokio = { version = "1", features = ["sync"] }
Expand All @@ -32,7 +32,7 @@ web-sys = { version = "0.3", features = [
"Window",
] }

openidconnect = { version = "3.0", optional = true }
openidconnect = { version = "4.0.0-rc.1", optional = true, features = ["timing-resistant-secret-traits"] }
yew-nested-router = { version = "0.7.0", optional = true }

[features]
Expand Down
35 changes: 20 additions & 15 deletions src/agent/client/oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ use crate::{
};
use ::oauth2::{
basic::{BasicClient, BasicTokenResponse},
reqwest::async_http_client,
url::Url,
AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, PkceCodeVerifier,
RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
};

use ::oauth2::{EndpointNotSet, EndpointSet};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

type ExtendedBasicClient =
BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LoginState {
pub pkce_verifier: String,
Expand All @@ -25,7 +29,8 @@ pub struct LoginState {
/// An OAuth2 based client implementation
#[derive(Clone, Debug)]
pub struct OAuth2Client {
client: BasicClient,
client: ExtendedBasicClient,
http_client: reqwest::Client,
}

impl OAuth2Client {
Expand Down Expand Up @@ -54,19 +59,19 @@ impl Client for OAuth2Client {
token_url,
} = config;

let client = BasicClient::new(
ClientId::new(client_id),
None,
AuthUrl::new(auth_url)
.map_err(|err| OAuth2Error::Configuration(format!("invalid auth URL: {err}")))?,
Some(
TokenUrl::new(token_url).map_err(|err| {
let client =
BasicClient::new(ClientId::new(client_id))
.set_auth_uri(AuthUrl::new(auth_url).map_err(|err| {
OAuth2Error::Configuration(format!("invalid auth URL: {err}"))
})?)
.set_token_uri(TokenUrl::new(token_url).map_err(|err| {
OAuth2Error::Configuration(format!("invalid token URL: {err}"))
})?,
),
);
})?);

Ok(Self { client })
Ok(Self {
client,
http_client: reqwest::Client::new(),
})
}

fn set_redirect_uri(mut self, url: Url) -> Self {
Expand Down Expand Up @@ -123,7 +128,7 @@ impl Client for OAuth2Client {
.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(pkce_verifier)
.request_async(async_http_client)
.request_async(&self.http_client)
.await
.map_err(|err| OAuth2Error::LoginResult(format!("failed to exchange code: {err}")))?;

Expand All @@ -140,7 +145,7 @@ impl Client for OAuth2Client {
let result = self
.client
.exchange_refresh_token(&RefreshToken::new(refresh_token))
.request_async(async_http_client)
.request_async(&self.http_client)
.await
.map_err(|err| {
OAuth2Error::Refresh(format!("failed to exchange refresh token: {err}"))
Expand Down
38 changes: 25 additions & 13 deletions src/agent/client/openid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,30 @@ use crate::{
};
use async_trait::async_trait;
use gloo_utils::window;
use oauth2::TokenResponse;
use oauth2::{EndpointMaybeSet, EndpointNotSet, EndpointSet, TokenResponse};
use openidconnect::{
core::{
CoreAuthDisplay, CoreAuthenticationFlow, CoreClaimName, CoreClaimType, CoreClient,
CoreClientAuthMethod, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType,
CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm,
CoreJwsSigningAlgorithm, CoreResponseMode, CoreResponseType, CoreSubjectIdentifierType,
CoreTokenResponse,
CoreClientAuthMethod, CoreGenderClaim, CoreGrantType, CoreJsonWebKey,
CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, CoreResponseMode,
CoreResponseType, CoreSubjectIdentifierType, CoreTokenResponse,
},
reqwest::async_http_client,
AuthorizationCode, ClientId, CsrfToken, EmptyAdditionalClaims, IdTokenClaims, IssuerUrl, Nonce,
PkceCodeChallenge, PkceCodeVerifier, ProviderMetadata, RedirectUrl, RefreshToken, Scope,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use std::{fmt::Debug, rc::Rc};

type ExtendedCoreClient = CoreClient<
EndpointSet,
EndpointNotSet,
EndpointNotSet,
EndpointNotSet,
EndpointMaybeSet,
EndpointMaybeSet,
>;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OpenIdLoginState {
pub pkce_verifier: String,
Expand All @@ -37,7 +44,7 @@ const DEFAULT_POST_LOGOUT_DIRECT_NAME: &str = "post_logout_redirect_uri";
#[derive(Clone, Debug)]
pub struct OpenIdClient {
/// The client
client: CoreClient,
client: ExtendedCoreClient,
/// An override for the URL to end the session (logout)
end_session_url: Option<Url>,
/// A URL to direct to after the logout was performed
Expand All @@ -46,6 +53,7 @@ pub struct OpenIdClient {
post_logout_redirect_name: Option<String>,
/// Additional audiences of the ID token which are considered trustworthy
additional_trusted_audiences: Vec<String>,
http_client: reqwest::Client,
}

/// Additional metadata read from the discovery endpoint
Expand All @@ -66,9 +74,6 @@ pub type ExtendedProviderMetadata = ProviderMetadata<
CoreGrantType,
CoreJweContentEncryptionAlgorithm,
CoreJweKeyManagementAlgorithm,
CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
CoreJsonWebKeyUse,
CoreJsonWebKey,
CoreResponseMode,
CoreResponseType,
Expand All @@ -95,10 +100,12 @@ impl Client for OpenIdClient {
additional_trusted_audiences,
} = config;

let http_client = reqwest::Client::new();

let issuer = IssuerUrl::new(issuer_url)
.map_err(|err| OAuth2Error::Configuration(format!("invalid issuer URL: {err}")))?;

let metadata = ExtendedProviderMetadata::discover_async(issuer, async_http_client)
let metadata = ExtendedProviderMetadata::discover_async(issuer, &http_client)
.await
.map_err(|err| {
OAuth2Error::Configuration(format!("Failed to discover client: {err}"))
Expand All @@ -120,6 +127,7 @@ impl Client for OpenIdClient {
after_logout_url,
post_logout_redirect_name,
additional_trusted_audiences,
http_client,
})
}

Expand Down Expand Up @@ -176,8 +184,9 @@ impl Client for OpenIdClient {
let result = self
.client
.exchange_code(AuthorizationCode::new(code))
.map_err(|err| OAuth2Error::Configuration(format!("failed to exchange code: {err}")))?
.set_pkce_verifier(pkce_verifier)
.request_async(async_http_client)
.request_async(&self.http_client)
.await
.map_err(|err| OAuth2Error::LoginResult(format!("failed to exchange code: {err}")))?;

Expand Down Expand Up @@ -223,7 +232,10 @@ impl Client for OpenIdClient {
let result = self
.client
.exchange_refresh_token(&RefreshToken::new(refresh_token))
.request_async(async_http_client)
.map_err(|err| {
OAuth2Error::Configuration(format!("failed to exchange refresh token: {err}"))
})?
.request_async(&self.http_client)
.await
.map_err(|err| {
OAuth2Error::Refresh(format!("failed to exchange refresh token: {err}"))
Expand Down

0 comments on commit 6e3e4c1

Please sign in to comment.