diff --git a/CustomIdentityComponent/lambda_rust/Cargo.toml b/CustomIdentityComponent/lambda_rust/Cargo.toml new file mode 100644 index 0000000..203d5aa --- /dev/null +++ b/CustomIdentityComponent/lambda_rust/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "refresh-access-token" +version = "0.1.0" +edition = "2021" + +[dependencies] +aws-config = "0" +aws-sdk-secretsmanager = "0" +cached = { version ="0", features = ["async"] } +http = "0.2" +jsonwebkey = { version ="0.3", features = ["jwt-convert"] } +jsonwebtoken = "8" +lambda_http = { version = "0.8", default-features = false, features = ["apigw_http"] } +lambda_runtime = "0.8" +metrics = "0.21.1" +metrics_cloudwatch_embedded = { version = "0.4.1", features = ["lambda"] } +reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"] } +reqwest-middleware = "0.2" +reqwest-retry = "0.3" +serde = {version = "1.0", features = ["derive"] } +serde_json = "1.0" +time = "0.3" +tokio = { version = "1", features = ["macros"] } +tracing = { version = "0.1", features = ["log"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "env-filter", "json"] } + +[profile.release] +opt-level = "z" +lto = true +codegen-units = 1 +panic = "abort" \ No newline at end of file diff --git a/CustomIdentityComponent/lambda_rust/src/main.rs b/CustomIdentityComponent/lambda_rust/src/main.rs new file mode 100644 index 0000000..a27ee6d --- /dev/null +++ b/CustomIdentityComponent/lambda_rust/src/main.rs @@ -0,0 +1,257 @@ +use cached::proc_macro::{cached, once}; +use lambda_http::{Body, Error, Request, RequestExt, Response}; +use metrics_cloudwatch_embedded::lambda::handler::run_http; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tracing::{debug, info, info_span}; + +/// Input Jwt token claims +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct ClaimsIn { + sub: String, + iss: String, + kid: String, + aud: String, + scope: String, + access_token_scope: Option, + iat: i64, + nbf: i64, + exp: i64, +} + +/// Output Jwt token claims with references to save some allocations +#[derive(Debug, Serialize)] +struct ClaimsOut<'a> { + sub: &'a str, + iss: &'a str, + kid: &'a str, + aud: &'a str, + scope: &'a str, + access_token_scope: Option<&'a str>, + iat: i64, + nbf: i64, + exp: i64, +} + +/// Json body of success responses +#[derive(Debug, Serialize)] +struct ResponsePayload<'a> { + user_id: &'a str, + auth_token: &'a str, + refresh_token: &'a str, + auth_token_expires_in: i64, + refresh_token_expires_in: i64, +} + +fn generate_response(code: u16, body: &str) -> Response { + Response::builder() + .status(code) + .header("Access-Control-Allow-Origin", "*") + .header("Access-Control-Allow-Credentials", "true") + .body(body.into()) + .expect("failed to generate response") +} + +#[cached] +/// get our (cached) aws configuration +async fn get_aws_config() -> Arc { + Arc::new(aws_config::load_from_env().await) +} + +#[cached(time = 900)] +/// get our private kid and key from secrects manager, panic on failure +async fn get_private_key() -> (Arc, Arc) { + info!("refreshing private key from Secrets Manager"); + + let aws_config = get_aws_config().await; + let secrets_client = aws_sdk_secretsmanager::Client::new(&aws_config); + + let jwk: jsonwebkey::JsonWebKey = secrets_client + .get_secret_value() + .secret_id(std::env::var("SECRET_KEY_ID").unwrap()) + .send() + .await + .expect("failed to get SECRET_KEY_ID") + .secret_string() + .expect("SECRET_KEY_ID is blank") + .to_string() + .parse() + .expect("private key is not a valid jwk"); + + ( + Arc::new(jwk.key_id.unwrap()), + Arc::new(jsonwebtoken::EncodingKey::from_rsa_pem(jwk.key.to_pem().as_bytes()).unwrap()), + ) +} + +#[once(time = 900)] +/// get the json web keyset for our issuer, panic on failure +async fn get_keyset(issuer: &str) -> Arc> { + info!("Refreshing json web keyset"); + + use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; + + let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3); + let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new()) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build(); + + let jwks = client + .get(format!("{issuer}/.well-known/jwks.json")) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + + let mut dict = HashMap::new(); + for jwk in jwks.keys { + if let (Some(key_id), jsonwebtoken::jwk::AlgorithmParameters::RSA(rsa)) = + (jwk.common.key_id, &jwk.algorithm) + { + dict.insert( + key_id, + jsonwebtoken::DecodingKey::from_rsa_components(&rsa.n, &rsa.e).unwrap(), + ); + } + } + + if dict.is_empty() { + panic!("jwks has no valid keys"); + } + + Arc::new(dict) +} + +async fn process_token(issuer: &str, refresh_token: &str) -> Result, Error> { + let header = jsonwebtoken::decode_header(refresh_token)?; + let kid = header.kid.ok_or("kid missing from jwt header")?; + + let jks = get_keyset(issuer).await; + let public_key = jks.get(&kid).ok_or("kid not in jks")?; + + let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256); + validation.set_audience(&["refresh"]); + validation.set_issuer(&[issuer]); + + let jwt = jsonwebtoken::decode::(refresh_token, public_key, &validation)?; + debug!("jwt = {jwt:?}"); + + let user_id = jwt.claims.sub.as_str(); + let access_token_scope = &jwt + .claims + .access_token_scope + .ok_or("missing access_token_scope claim")?; + let access_token_duration_sec = 15 * 60; + let existing_exp_value = jwt.claims.exp; + + let (private_kid, private_key) = get_private_key().await; + + // Build a new header with the latest kid + let mut new_header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256); + new_header.kid = Some(private_kid.to_string()); + + let now = time::OffsetDateTime::now_utc().unix_timestamp(); + + // Build a new refresh token + let refresh_claims = ClaimsOut { + sub: user_id, + iss: issuer, + kid: &private_kid, + aud: "refresh", + scope: "refresh", + access_token_scope: Some(access_token_scope), + iat: now, + nbf: now, + exp: existing_exp_value, + }; + let refresh_token = jsonwebtoken::encode(&new_header, &refresh_claims, &private_key)?; + + // Build a new access token + let access_claims = ClaimsOut { + sub: user_id, + iss: issuer, + kid: &private_kid, + aud: "gamebackend", + scope: access_token_scope, + access_token_scope: None, + iat: now, + nbf: now, + exp: now + access_token_duration_sec, + }; + let access_token = jsonwebtoken::encode(&new_header, &access_claims, &private_key)?; + + let response_payload = ResponsePayload { + user_id, + auth_token: &access_token, + auth_token_expires_in: access_token_duration_sec, + refresh_token: &refresh_token, + refresh_token_expires_in: existing_exp_value - now, + }; + + Ok(generate_response( + 200, + &serde_json::to_string(&response_payload)?, + )) +} + +async fn function_handler(issuer: &str, request: Request) -> Result, Error> { + // Get the refresh_token from the query string + let query = request.query_string_parameters(); + let refresh_token = query.first("refresh_token"); + + match refresh_token { + None => { + metrics::increment_counter!("deny", "reason" => "No refresh token provided"); + Ok(generate_response(401, "Error: No refresh token provided")) + } + Some(refresh_token) => match process_token(issuer, refresh_token).await { + Ok(response) => { + metrics::increment_counter!("allow"); + Ok(response) + } + Err(e) => { + // Record the details but don't give the remote client specifics + metrics::increment_counter!("deny", "reason" => e.to_string()); + Ok(generate_response( + 401, + "Error: Failed to validate refresh token", + )) + } + }, + } +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + tracing_subscriber::fmt() + .json() + .with_env_filter(tracing_subscriber::filter::EnvFilter::from_default_env()) + .with_target(false) + .with_current_span(false) + .without_time() + .init(); + + let issuer = std::env::var("ISSUER_URL").unwrap(); + + let metrics = metrics_cloudwatch_embedded::Builder::new() + .cloudwatch_namespace(std::env::var("POWERTOOLS_METRICS_NAMESPACE").unwrap()) + .with_dimension("service", std::env::var("POWERTOOLS_SERVICE_NAME").unwrap()) + .with_dimension( + "function", + std::env::var("AWS_LAMBDA_FUNCTION_NAME").unwrap(), + ) + .lambda_cold_start_span(info_span!("cold start").entered()) + .lambda_cold_start_metric("ColdStart") + .with_lambda_request_id("requestId") + .init() + .unwrap(); + + run_http(metrics, |request: Request| { + function_handler(&issuer, request) + }) + .await +} diff --git a/CustomIdentityComponent/lib/custom_identity_component-stack.ts b/CustomIdentityComponent/lib/custom_identity_component-stack.ts index cd7beac..7f460c9 100644 --- a/CustomIdentityComponent/lib/custom_identity_component-stack.ts +++ b/CustomIdentityComponent/lib/custom_identity_component-stack.ts @@ -3,6 +3,7 @@ import { Stack, StackProps, CfnOutput, Duration } from 'aws-cdk-lib'; import { Construct } from 'constructs'; +import { RustFunction } from 'cargo-lambda-cdk'; import * as lambda from 'aws-cdk-lib/aws-lambda'; import * as cloudfront from 'aws-cdk-lib/aws-cloudfront'; import * as s3 from 'aws-cdk-lib/aws-s3'; @@ -362,8 +363,38 @@ export class CustomIdentityComponentStack extends Stack { { id: 'AwsSolutions-IAM5', reason: 'Using the standard Lambda execution role, all custom access resource restricted.' } ], true); - // Map login_as_guest_function to the api_gateway GET requeste login_as_guest - api_gateway.root.addResource('refresh-access-token').addMethod('GET', new apigw.LambdaIntegration(refresh_access_token_function),{ + // Map refresh_access_token_function to the api_gateway GET requeste refresh_access_token_function + api_gateway.root.addResource('refresh-access-token').addMethod('GET', new apigw.LambdaIntegration(refresh_access_token_function), { + requestParameters: { + 'method.request.querystring.refresh_token': true + }, + requestValidator: requestValidator + }); + + const refresh_access_token_rust_function = new RustFunction(this, 'RefreshAccessToken_Rust', { + role: refresh_access_token_function_role, + manifestPath: 'lambda_rust/Cargo.toml', + architecture: lambda.Architecture.ARM_64, + bundling: { + forcedDockerBundling: true, + }, + timeout: Duration.seconds(5), + tracing: lambda.Tracing.ACTIVE, + memorySize: 256, + environment: { + "ISSUER_URL": "https://" + distribution.domainName, + "POWERTOOLS_METRICS_NAMESPACE": "AWS for Games", + "POWERTOOLS_SERVICE_NAME": "CustomIdentityComponent", + "RUST_LOG": "info", + "SECRET_KEY_ID": secret.secretName, + "USER_TABLE": user_table.tableName + } + }); + secret.grantRead(refresh_access_token_rust_function); + user_table.grantReadWriteData(refresh_access_token_rust_function); + + // Map refresh_access_token_function to the api_gateway GET requeste refresh_access_token_function + api_gateway.root.addResource('refresh-access-token-rust').addMethod('GET', new apigw.LambdaIntegration(refresh_access_token_rust_function), { requestParameters: { 'method.request.querystring.refresh_token': true }, diff --git a/CustomIdentityComponent/package.json b/CustomIdentityComponent/package.json index 458e93d..16d5f88 100644 --- a/CustomIdentityComponent/package.json +++ b/CustomIdentityComponent/package.json @@ -21,6 +21,7 @@ }, "dependencies": { "aws-cdk-lib": "^2.97.0", + "cargo-lambda-cdk": "^0.0.16", "cdk": "^2.81.0-alpha.0", "cdk-nag": "^2.27.24", "constructs": "^10.0.0",