From 042f768b5601513cf42173c297ba3eae04e5b86a Mon Sep 17 00:00:00 2001 From: Benjamin Sparks <b.sparks@alugha.com> Date: Sat, 19 Oct 2024 13:31:36 +0200 Subject: [PATCH] Implement, test and export Scheme extractor --- axum-extra/Cargo.toml | 1 + axum-extra/src/extract/host.rs | 2 +- axum-extra/src/extract/mod.rs | 7 ++ axum-extra/src/extract/scheme.rs | 152 +++++++++++++++++++++++++++++++ 4 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 axum-extra/src/extract/scheme.rs diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index a35948da60..52072240c6 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -33,6 +33,7 @@ json-lines = [ ] multipart = ["dep:multer", "dep:fastrand"] protobuf = ["dep:prost"] +scheme = [] query = ["dep:serde_html_form"] tracing = ["axum-core/tracing", "axum/tracing"] typed-header = ["dep:headers"] diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index 477bc4fa94..981d7d71fd 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -139,7 +139,7 @@ mod tests { assert_eq!(value, "192.0.2.60"); // is case insensitive - let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); + let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index a0e710d1a5..b020fe6be2 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -56,3 +56,10 @@ pub use crate::json_lines::JsonLines; #[cfg(feature = "typed-header")] #[doc(no_inline)] pub use crate::typed_header::TypedHeader; + +#[cfg(feature = "scheme")] +pub mod scheme; + +#[cfg(feature = "scheme")] +#[doc(no_inline)] +pub use self::scheme::{SchemeMissing, Scheme}; diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs new file mode 100644 index 0000000000..891d5c0bdd --- /dev/null +++ b/axum-extra/src/extract/scheme.rs @@ -0,0 +1,152 @@ +//! Extractor that parses the scheme of a request. +//! See [`Scheme`] for more details. + +use axum::{ + extract::FromRequestParts, + response::{IntoResponse, Response}, +}; +use http::{ + header::{HeaderMap, FORWARDED}, + request::Parts, +}; +const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto"; + +/// Extractor that resolves the scheme / protocol of a request. +/// +/// The scheme is resolved through the following, in order: +/// - `Forwarded` header +/// - `X-Forwarded-Proto` header +/// - Request URI (If the request is an HTTP/2 request! e.g. use `--http2(-prior-knowledge)` with cURL) +/// +/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make +/// sure to validate them to avoid security issues. +#[derive(Debug, Clone)] +pub struct Scheme(pub String); + +/// Rejection type used if the [`Scheme`] extractor is unable to +/// resolve a scheme. +#[derive(Debug)] +pub struct SchemeMissing; + +impl IntoResponse for SchemeMissing { + fn into_response(self) -> Response { + (http::StatusCode::BAD_REQUEST, "No scheme found in request").into_response() + } +} + +impl<S> FromRequestParts<S> for Scheme +where + S: Send + Sync, +{ + type Rejection = SchemeMissing; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + // Within Forwarded header + if let Some(scheme) = parse_forwarded(&parts.headers) { + return Ok(Scheme(scheme.to_owned())); + } + + // X-Forwarded-Proto + if let Some(scheme) = parts + .headers + .get(X_FORWARDED_PROTO_HEADER_KEY) + .and_then(|scheme| scheme.to_str().ok()) + { + return Ok(Scheme(scheme.to_owned())); + } + + // From parts of an HTTP/2 request + if let Some(scheme) = parts.uri.scheme_str() { + return Ok(Scheme(scheme.to_owned())); + } + + Err(SchemeMissing) + } +} + +fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { + // if there are multiple `Forwarded` `HeaderMap::get` will return the first one + let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?; + + // get the first set of values + let first_value = forwarded_values.split(',').next()?; + + // find the value of the `proto` field + first_value.split(';').find_map(|pair| { + let (key, value) = pair.split_once('=')?; + key.trim() + .eq_ignore_ascii_case("proto") + .then(|| value.trim().trim_matches('"')) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::TestClient; + use axum::{routing::get, Router}; + use http::header::HeaderName; + + fn test_client() -> TestClient { + async fn scheme_as_body(Scheme(scheme): Scheme) -> String { + scheme + } + + TestClient::new(Router::new().route("/", get(scheme_as_body))) + } + + #[crate::test] + async fn forwarded_scheme_parsing() { + // the basic case + let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "http"); + + // is case insensitive + let headers = header_map(&[(FORWARDED, "host=192.0.2.60;PROTO=https;by=203.0.113.43")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "https"); + + // multiple values in one header + let headers = header_map(&[(FORWARDED, "proto=ftp, proto=https")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "ftp"); + + // multiple header values + let headers = header_map(&[(FORWARDED, "proto=ftp"), (FORWARDED, "proto=https")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "ftp"); + } + + #[crate::test] + async fn x_forwarded_scheme_header() { + let original_scheme = "https"; + let scheme = test_client() + .get("/") + .header(X_FORWARDED_PROTO_HEADER_KEY, original_scheme) + .await + .text() + .await; + assert_eq!(scheme, original_scheme); + } + + #[crate::test] + async fn precedence_forwarded_over_x_forwarded() { + let scheme = test_client() + .get("/") + .header(X_FORWARDED_PROTO_HEADER_KEY, "https") + .header(FORWARDED, "proto=ftp") + .await + .text() + .await; + assert_eq!(scheme, "ftp"); + } + + fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap { + let mut headers = HeaderMap::new(); + for (key, value) in values { + headers.append(key, value.parse().unwrap()); + } + headers + } +}