From 042f768b5601513cf42173c297ba3eae04e5b86a Mon Sep 17 00:00:00 2001
From: Benjamin Sparks <b.sparks@alugha.com>
Date: Sat, 19 Oct 2024 13:31:36 +0200
Subject: [PATCH] Implement, test and export Scheme extractor

---
 axum-extra/Cargo.toml            |   1 +
 axum-extra/src/extract/host.rs   |   2 +-
 axum-extra/src/extract/mod.rs    |   7 ++
 axum-extra/src/extract/scheme.rs | 152 +++++++++++++++++++++++++++++++
 4 files changed, 161 insertions(+), 1 deletion(-)
 create mode 100644 axum-extra/src/extract/scheme.rs

diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml
index a35948da60..52072240c6 100644
--- a/axum-extra/Cargo.toml
+++ b/axum-extra/Cargo.toml
@@ -33,6 +33,7 @@ json-lines = [
 ]
 multipart = ["dep:multer", "dep:fastrand"]
 protobuf = ["dep:prost"]
+scheme = []
 query = ["dep:serde_html_form"]
 tracing = ["axum-core/tracing", "axum/tracing"]
 typed-header = ["dep:headers"]
diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs
index 477bc4fa94..981d7d71fd 100644
--- a/axum-extra/src/extract/host.rs
+++ b/axum-extra/src/extract/host.rs
@@ -139,7 +139,7 @@ mod tests {
         assert_eq!(value, "192.0.2.60");
 
         // is case insensitive
-        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
+        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
         let value = parse_forwarded(&headers).unwrap();
         assert_eq!(value, "192.0.2.60");
 
diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs
index a0e710d1a5..b020fe6be2 100644
--- a/axum-extra/src/extract/mod.rs
+++ b/axum-extra/src/extract/mod.rs
@@ -56,3 +56,10 @@ pub use crate::json_lines::JsonLines;
 #[cfg(feature = "typed-header")]
 #[doc(no_inline)]
 pub use crate::typed_header::TypedHeader;
+
+#[cfg(feature = "scheme")]
+pub mod scheme;
+
+#[cfg(feature = "scheme")]
+#[doc(no_inline)]
+pub use self::scheme::{SchemeMissing, Scheme};
diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs
new file mode 100644
index 0000000000..891d5c0bdd
--- /dev/null
+++ b/axum-extra/src/extract/scheme.rs
@@ -0,0 +1,152 @@
+//! Extractor that parses the scheme of a request.
+//! See [`Scheme`] for more details.
+
+use axum::{
+    extract::FromRequestParts,
+    response::{IntoResponse, Response},
+};
+use http::{
+    header::{HeaderMap, FORWARDED},
+    request::Parts,
+};
+const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto";
+
+/// Extractor that resolves the scheme / protocol of a request.
+///
+/// The scheme is resolved through the following, in order:
+/// - `Forwarded` header
+/// - `X-Forwarded-Proto` header
+/// - Request URI (If the request is an HTTP/2 request! e.g. use `--http2(-prior-knowledge)` with cURL)
+///
+/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make
+/// sure to validate them to avoid security issues.
+#[derive(Debug, Clone)]
+pub struct Scheme(pub String);
+
+/// Rejection type used if the [`Scheme`] extractor is unable to
+/// resolve a scheme.
+#[derive(Debug)]
+pub struct SchemeMissing;
+
+impl IntoResponse for SchemeMissing {
+    fn into_response(self) -> Response {
+        (http::StatusCode::BAD_REQUEST, "No scheme found in request").into_response()
+    }
+}
+
+impl<S> FromRequestParts<S> for Scheme
+where
+    S: Send + Sync,
+{
+    type Rejection = SchemeMissing;
+
+    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+        // Within Forwarded header
+        if let Some(scheme) = parse_forwarded(&parts.headers) {
+            return Ok(Scheme(scheme.to_owned()));
+        }
+
+        // X-Forwarded-Proto
+        if let Some(scheme) = parts
+            .headers
+            .get(X_FORWARDED_PROTO_HEADER_KEY)
+            .and_then(|scheme| scheme.to_str().ok())
+        {
+            return Ok(Scheme(scheme.to_owned()));
+        }
+
+        // From parts of an HTTP/2 request
+        if let Some(scheme) = parts.uri.scheme_str() {
+            return Ok(Scheme(scheme.to_owned()));
+        }
+
+        Err(SchemeMissing)
+    }
+}
+
+fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
+    // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
+    let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
+
+    // get the first set of values
+    let first_value = forwarded_values.split(',').next()?;
+
+    // find the value of the `proto` field
+    first_value.split(';').find_map(|pair| {
+        let (key, value) = pair.split_once('=')?;
+        key.trim()
+            .eq_ignore_ascii_case("proto")
+            .then(|| value.trim().trim_matches('"'))
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test_helpers::TestClient;
+    use axum::{routing::get, Router};
+    use http::header::HeaderName;
+
+    fn test_client() -> TestClient {
+        async fn scheme_as_body(Scheme(scheme): Scheme) -> String {
+            scheme
+        }
+
+        TestClient::new(Router::new().route("/", get(scheme_as_body)))
+    }
+
+    #[crate::test]
+    async fn forwarded_scheme_parsing() {
+        // the basic case
+        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
+        let value = parse_forwarded(&headers).unwrap();
+        assert_eq!(value, "http");
+
+        // is case insensitive
+        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;PROTO=https;by=203.0.113.43")]);
+        let value = parse_forwarded(&headers).unwrap();
+        assert_eq!(value, "https");
+
+        // multiple values in one header
+        let headers = header_map(&[(FORWARDED, "proto=ftp, proto=https")]);
+        let value = parse_forwarded(&headers).unwrap();
+        assert_eq!(value, "ftp");
+
+        // multiple header values
+        let headers = header_map(&[(FORWARDED, "proto=ftp"), (FORWARDED, "proto=https")]);
+        let value = parse_forwarded(&headers).unwrap();
+        assert_eq!(value, "ftp");
+    }
+
+    #[crate::test]
+    async fn x_forwarded_scheme_header() {
+        let original_scheme = "https";
+        let scheme = test_client()
+            .get("/")
+            .header(X_FORWARDED_PROTO_HEADER_KEY, original_scheme)
+            .await
+            .text()
+            .await;
+        assert_eq!(scheme, original_scheme);
+    }
+
+    #[crate::test]
+    async fn precedence_forwarded_over_x_forwarded() {
+        let scheme = test_client()
+            .get("/")
+            .header(X_FORWARDED_PROTO_HEADER_KEY, "https")
+            .header(FORWARDED, "proto=ftp")
+            .await
+            .text()
+            .await;
+        assert_eq!(scheme, "ftp");
+    }
+
+    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
+        let mut headers = HeaderMap::new();
+        for (key, value) in values {
+            headers.append(key, value.parse().unwrap());
+        }
+        headers
+    }
+}