diff --git a/core/Cargo.toml b/core/Cargo.toml index 31690b12ed..51c500993b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -30,6 +30,9 @@ parking_lot = { version = "0.12", optional = true } tokio = { version = "1.16", optional = true } wasm-bindgen-futures = { version = "0.4.19", optional = true } futures-timer = { version = "3", optional = true } +globset = { version = "0.4", optional = true } +lazy_static = { version = "1", optional = true } +unicase = { version = "2.6.0", optional = true } [features] default = [] @@ -37,12 +40,15 @@ http-helpers = ["hyper", "futures-util"] server = [ "arrayvec", "futures-util/alloc", + "globset", "rustc-hash/std", "tracing", "parking_lot", "rand", "tokio/rt", "tokio/sync", + "lazy_static", + "unicase", ] client = ["futures-util/sink", "futures-channel/sink", "futures-channel/std"] async-client = [ diff --git a/core/src/error.rs b/core/src/error.rs index d50d13102c..2334011964 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -105,9 +105,12 @@ pub enum Error { /// Attempted to stop server that is already stopped. #[error("Attempted to stop server that is already stopped")] AlreadyStopped, - /// List passed into `set_allowed_origins` was empty + /// List passed into access control based on HTTP header verification. #[error("Must set at least one allowed value for the {0} header")] EmptyAllowList(&'static str), + /// Access control verification of HTTP headers failed. + #[error("HTTP header: `{0}` value: `{1}` verification failed")] + HttpHeaderRejected(&'static str, String), /// Failed to execute a method because a resource was already at capacity #[error("Resource at capacity: {0}")] ResourceAtCapacity(&'static str), diff --git a/core/src/http_helpers.rs b/core/src/http_helpers.rs index a8daaea3a7..469d0acac0 100644 --- a/core/src/http_helpers.rs +++ b/core/src/http_helpers.rs @@ -99,13 +99,25 @@ pub fn read_header_value<'a>(headers: &'a hyper::header::HeaderMap, header_name: pub fn read_header_values<'a>( headers: &'a hyper::header::HeaderMap, header_name: &str, -) -> hyper::header::ValueIter<'a, hyper::header::HeaderValue> { - headers.get_all(header_name).iter() +) -> hyper::header::GetAll<'a, hyper::header::HeaderValue> { + headers.get_all(header_name) +} + +/// Get the header values from the `access-control-request-headers` header. +pub fn get_cors_request_headers<'a>(headers: &'a hyper::header::HeaderMap) -> impl Iterator { + const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers"; + + read_header_values(headers, ACCESS_CONTROL_REQUEST_HEADERS) + .iter() + .filter_map(|val| val.to_str().ok()) + .flat_map(|val| val.split(",")) + // The strings themselves might contain leading and trailing whitespaces + .map(|s| s.trim()) } #[cfg(test)] mod tests { - use super::{read_body, read_header_content_length}; + use super::{get_cors_request_headers, read_body, read_header_content_length}; #[tokio::test] async fn body_to_bytes_size_limit_works() { @@ -130,4 +142,23 @@ mod tests { headers.insert(hyper::header::CONTENT_LENGTH, "18446744073709551616".parse().unwrap()); assert_eq!(read_header_content_length(&headers), None); } + + #[test] + fn get_cors_headers_works() { + let mut headers = hyper::header::HeaderMap::new(); + + // access-control-request-headers + headers.insert(hyper::header::ACCESS_CONTROL_REQUEST_HEADERS, "Content-Type,x-requested-with".parse().unwrap()); + + let values: Vec<&str> = get_cors_request_headers(&headers).collect(); + assert_eq!(values, vec!["Content-Type", "x-requested-with"]); + + headers.insert( + hyper::header::ACCESS_CONTROL_REQUEST_HEADERS, + "Content-Type, x-requested-with ".parse().unwrap(), + ); + + let values: Vec<&str> = get_cors_request_headers(&headers).collect(); + assert_eq!(values, vec!["Content-Type", "x-requested-with"]); + } } diff --git a/http-server/src/access_control/cors.rs b/core/src/server/access_control/cors.rs similarity index 74% rename from http-server/src/access_control/cors.rs rename to core/src/server/access_control/cors.rs index ba807743bf..4f428ba128 100644 --- a/http-server/src/access_control/cors.rs +++ b/core/src/server/access_control/cors.rs @@ -29,9 +29,9 @@ use std::collections::HashSet; use std::{fmt, ops}; -use crate::access_control::hosts::{Host, Port}; -use crate::access_control::matcher::{Matcher, Pattern}; -use jsonrpsee_core::Cow; +use crate::server::access_control::host::{Host, Port}; +use crate::server::access_control::matcher::{Matcher, Pattern}; +use crate::Cow; use lazy_static::lazy_static; use unicase::Ascii; @@ -128,54 +128,54 @@ impl ops::Deref for Origin { /// Origins allowed to access #[derive(Debug, Clone, PartialEq, Eq)] -pub enum AccessControlAllowOrigin { - /// Specific hostname - Value(Origin), +pub enum AllowOrigin { + /// Specific origin. + Origin(Origin), /// null-origin (file:///, sandboxed iframe) Null, /// Any non-null origin Any, } -impl fmt::Display for AccessControlAllowOrigin { +impl fmt::Display for AllowOrigin { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "{}", match *self { - AccessControlAllowOrigin::Any => "*", - AccessControlAllowOrigin::Null => "null", - AccessControlAllowOrigin::Value(ref val) => val, + Self::Any => "*", + Self::Null => "null", + Self::Origin(ref val) => val, } ) } } -impl> From for AccessControlAllowOrigin { - fn from(s: T) -> AccessControlAllowOrigin { +impl> From for AllowOrigin { + fn from(s: T) -> Self { match s.into().as_str() { - "all" | "*" | "any" => AccessControlAllowOrigin::Any, - "null" => AccessControlAllowOrigin::Null, - origin => AccessControlAllowOrigin::Value(origin.into()), + "all" | "*" | "any" => Self::Any, + "null" => Self::Null, + origin => Self::Origin(origin.into()), } } } /// Headers allowed to access #[derive(Debug, Clone, PartialEq)] -pub enum AccessControlAllowHeaders { +pub enum AllowHeaders { /// Specific headers Only(Vec), /// Any header Any, } -impl AccessControlAllowHeaders { +impl AllowHeaders { /// Return an appropriate value for the CORS header "Access-Control-Allow-Headers". pub fn to_cors_header_value(&self) -> Cow<'_, str> { match self { - AccessControlAllowHeaders::Any => "*".into(), - AccessControlAllowHeaders::Only(headers) => headers.join(", ").into(), + AllowHeaders::Any => "*".into(), + AllowHeaders::Only(headers) => headers.join(", ").into(), } } } @@ -219,11 +219,11 @@ impl From> for Option { } /// Returns correct CORS header (if any) given list of allowed origins and current origin. -pub(crate) fn get_cors_allow_origin( +pub(super) fn get_cors_allow_origin( origin: Option<&str>, + allowed: &Option>, host: Option<&str>, - allowed: &Option>, -) -> AllowCors { +) -> AllowCors { match origin { None => AllowCors::NotRequired, Some(ref origin) => { @@ -239,22 +239,22 @@ pub(crate) fn get_cors_allow_origin( } match allowed.as_ref() { - None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null), - None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))), + None if *origin == "null" => AllowCors::Ok(AllowOrigin::Null), + None => AllowCors::Ok(AllowOrigin::Origin(Origin::parse(origin))), Some(allowed) if *origin == "null" => allowed .iter() - .find(|cors| **cors == AccessControlAllowOrigin::Null) + .find(|cors| **cors == AllowOrigin::Null) .cloned() .map(AllowCors::Ok) .unwrap_or(AllowCors::Invalid), Some(allowed) => allowed .iter() .find(|cors| match **cors { - AccessControlAllowOrigin::Any => true, - AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true, + AllowOrigin::Any => true, + AllowOrigin::Origin(ref val) if val.matches(origin) => true, _ => false, }) - .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin))) + .map(|_| AllowOrigin::Origin(Origin::parse(origin))) .map(AllowCors::Ok) .unwrap_or(AllowCors::Invalid), } @@ -262,15 +262,19 @@ pub(crate) fn get_cors_allow_origin( } } -/// Validates if the `AccessControlAllowedHeaders` in the request are allowed. +/// Validates if the headers in the request are allowed. +/// +/// headers: all the headers in the request. +/// cors_request_headers: `values` in the `access-control-request-headers` header. +/// cors_allow_headers: whitelisted headers by the user. pub(crate) fn get_cors_allow_headers, O, F: Fn(T) -> O>( mut headers: impl Iterator, - requested_headers: impl Iterator, - cors_allow_headers: &AccessControlAllowHeaders, + cors_request_headers: impl Iterator, + cors_allow_headers: &AllowHeaders, to_result: F, ) -> AllowCors> { // Check if the header fields which were sent in the request are allowed - if let AccessControlAllowHeaders::Only(only) = cors_allow_headers { + if let AllowHeaders::Only(only) = cors_allow_headers { let are_all_allowed = headers.all(|header| { let name = &Ascii::new(header.as_ref()); only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) @@ -283,13 +287,13 @@ pub(crate) fn get_cors_allow_headers, O, F: Fn(T) -> O>( // Check if `AccessControlRequestHeaders` contains fields which were allowed let (filtered, headers) = match cors_allow_headers { - AccessControlAllowHeaders::Any => { - let headers = requested_headers.map(to_result).collect(); + AllowHeaders::Any => { + let headers = cors_request_headers.map(to_result).collect(); (false, headers) } - AccessControlAllowHeaders::Only(only) => { + AllowHeaders::Only(only) => { let mut filtered = false; - let headers: Vec<_> = requested_headers + let headers: Vec<_> = cors_request_headers .filter(|header| { let name = &Ascii::new(header.as_ref()); filtered = true; @@ -319,7 +323,6 @@ lazy_static! { let mut hs = HashSet::new(); hs.insert(Ascii::new("Accept")); hs.insert(Ascii::new("Accept-Language")); - hs.insert(Ascii::new("Access-Control-Allow-Origin")); hs.insert(Ascii::new("Access-Control-Request-Headers")); hs.insert(Ascii::new("Content-Language")); hs.insert(Ascii::new("Content-Type")); @@ -337,7 +340,7 @@ mod tests { use std::iter; use super::*; - use crate::access_control::hosts::Host; + use crate::server::access_control::host::Host; #[test] fn should_parse_origin() { @@ -365,8 +368,8 @@ mod tests { let host = Some(&*host); // when - let res1 = get_cors_allow_origin(origin1, host, &Some(vec![])); - let res2 = get_cors_allow_origin(origin2, host, &Some(vec![])); + let res1 = get_cors_allow_origin(origin1, &Some(vec![]), host); + let res2 = get_cors_allow_origin(origin2, &Some(vec![]), host); // then assert_eq!(res1, AllowCors::Invalid); @@ -383,7 +386,7 @@ mod tests { let host = Some(&*host); // when - let res = get_cors_allow_origin(origin, host, &None); + let res = get_cors_allow_origin(origin, &None, host); // then assert_eq!(res, AllowCors::NotRequired); @@ -396,7 +399,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin(origin, host, &None); + let res = get_cors_allow_origin(origin, &None, host); // then assert_eq!(res, AllowCors::NotRequired); @@ -409,7 +412,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin(origin, host, &None); + let res = get_cors_allow_origin(origin, &None, host); // then assert_eq!(res, AllowCors::Ok("parity.io".into())); @@ -422,11 +425,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]), - ); + let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Origin("http://ethereum.org".into())]), host); // then assert_eq!(res, AllowCors::NotRequired); @@ -439,7 +438,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin(origin, host, &Some(Vec::new())); + let res = get_cors_allow_origin(origin, &Some(Vec::new()), host); // then assert_eq!(res, AllowCors::NotRequired); @@ -452,11 +451,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]), - ); + let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Origin("http://ethereum.org".into())]), host); // then assert_eq!(res, AllowCors::Invalid); @@ -469,10 +464,10 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any])); + let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Any]), host); // then - assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!(res, AllowCors::Ok(AllowOrigin::Origin("http://parity.io".into()))); } #[test] @@ -482,7 +477,7 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); + let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Null]), host); // then assert_eq!(res, AllowCors::NotRequired); @@ -495,10 +490,10 @@ mod tests { let host = None; // when - let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); + let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Null]), host); // then - assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null)); + assert_eq!(res, AllowCors::Ok(AllowOrigin::Null)); } #[test] @@ -510,15 +505,15 @@ mod tests { // when let res = get_cors_allow_origin( origin, - host, &Some(vec![ - AccessControlAllowOrigin::Value("http://ethereum.org".into()), - AccessControlAllowOrigin::Value("http://parity.io".into()), + AllowOrigin::Origin("http://ethereum.org".into()), + AllowOrigin::Origin("http://parity.io".into()), ]), + host, ); // then - assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!(res, AllowCors::Ok(AllowOrigin::Origin("http://parity.io".into()))); } #[test] @@ -528,26 +523,24 @@ mod tests { let origin2 = Some("http://parity.iot"); let origin3 = Some("chrome-extension://test"); let host = None; - let allowed = Some(vec![ - AccessControlAllowOrigin::Value("http://*.io".into()), - AccessControlAllowOrigin::Value("chrome-extension://*".into()), - ]); + let allowed = + Some(vec![AllowOrigin::Origin("http://*.io".into()), AllowOrigin::Origin("chrome-extension://*".into())]); // when - let res1 = get_cors_allow_origin(origin1, host, &allowed); - let res2 = get_cors_allow_origin(origin2, host, &allowed); - let res3 = get_cors_allow_origin(origin3, host, &allowed); + let res1 = get_cors_allow_origin(origin1, &allowed, host); + let res2 = get_cors_allow_origin(origin2, &allowed, host); + let res3 = get_cors_allow_origin(origin3, &allowed, host); // then - assert_eq!(res1, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!(res1, AllowCors::Ok(AllowOrigin::Origin("http://parity.io".into()))); assert_eq!(res2, AllowCors::Invalid); - assert_eq!(res3, AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))); + assert_eq!(res3, AllowCors::Ok(AllowOrigin::Origin("chrome-extension://test".into()))); } #[test] fn should_return_invalid_if_header_not_allowed() { // given - let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]); + let cors_allow_headers = AllowHeaders::Only(vec!["x-allowed".to_owned()]); let headers = vec!["Access-Control-Request-Headers"]; let requested = vec!["x-not-allowed"]; @@ -562,7 +555,7 @@ mod tests { fn should_return_valid_if_header_allowed() { // given let allowed = vec!["x-allowed".to_owned()]; - let cors_allow_headers = AccessControlAllowHeaders::Only(allowed); + let cors_allow_headers = AllowHeaders::Only(allowed); let headers = vec!["Access-Control-Request-Headers"]; let requested = vec!["x-allowed"]; @@ -578,7 +571,7 @@ mod tests { fn should_return_no_allowed_headers_if_none_in_request() { // given let allowed = vec!["x-allowed".to_owned()]; - let cors_allow_headers = AccessControlAllowHeaders::Only(allowed); + let cors_allow_headers = AllowHeaders::Only(allowed); let headers: Vec = vec![]; // when @@ -591,7 +584,7 @@ mod tests { #[test] fn should_return_not_required_if_any_header_allowed() { // given - let cors_allow_headers = AccessControlAllowHeaders::Any; + let cors_allow_headers = AllowHeaders::Any; let headers: Vec = vec![]; // when diff --git a/http-server/src/access_control/hosts.rs b/core/src/server/access_control/host.rs similarity index 70% rename from http-server/src/access_control/hosts.rs rename to core/src/server/access_control/host.rs index 205a43b1ac..30dbc5178f 100644 --- a/http-server/src/access_control/hosts.rs +++ b/core/src/server/access_control/host.rs @@ -26,7 +26,8 @@ //! Host header validation. -use crate::access_control::matcher::{Matcher, Pattern}; +use crate::server::access_control::matcher::{Matcher, Pattern}; +use crate::Error; const SPLIT_PROOF: &str = "split always returns non-empty iterator."; @@ -139,47 +140,31 @@ impl std::ops::Deref for Host { } } -/// Specifies if domains should be validated. -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum DomainsValidation { - /// Allow only domains on the list. - AllowOnly(Vec), - /// Disable domains validation completely. - Disabled, +/// Policy for validating the `HTTP host header`. +#[derive(Debug, Clone)] +pub enum AllowHosts { + /// Allow all hosts (no filter). + Any, + /// Allow only specified hosts. + Only(Vec), } -impl From>> for DomainsValidation { - fn from(other: Option>) -> Self { - match other { - Some(list) => DomainsValidation::AllowOnly(list), - None => DomainsValidation::Disabled, +impl AllowHosts { + /// Verify a host. + pub fn verify(&self, value: &str) -> Result<(), Error> { + if let AllowHosts::Only(list) = self { + if !list.iter().any(|o| o.matches(value)) { + return Err(Error::HttpHeaderRejected("host", value.into())); + } } - } -} -/// Returns `true` when `Host` header is whitelisted in `allow_hosts`. -pub(crate) fn is_host_valid(host: Option<&str>, allow_hosts: &AllowHosts) -> bool { - match host { - None => false, - Some(ref host) => match allow_hosts { - AllowHosts::Any => true, - AllowHosts::Only(allow_hosts) => allow_hosts.iter().any(|h| h.matches(host)), - }, + Ok(()) } } -/// Allowed hosts for http header 'host' -#[derive(Clone, Debug)] -pub enum AllowHosts { - /// Allow requests from any host - Any, - /// Allow only a selection of specific hosts - Only(Vec), -} - #[cfg(test)] mod tests { - use super::{is_host_valid, AllowHosts, Host}; + use super::{AllowHosts, Host, Port}; #[test] fn should_parse_host() { @@ -188,43 +173,35 @@ mod tests { assert_eq!(Host::parse("chrome-extension://124.0.0.1"), Host::new("124.0.0.1", None)); assert_eq!(Host::parse("parity.io/somepath"), Host::new("parity.io", None)); assert_eq!(Host::parse("127.0.0.1:8545/somepath"), Host::new("127.0.0.1", Some(8545))); - } - #[test] - fn should_reject_when_there_is_no_header() { - let valid = is_host_valid(None, &AllowHosts::Any); - assert!(!valid); - let valid = is_host_valid(None, &AllowHosts::Only(vec![])); - assert!(!valid); + let host = Host::parse("*.domain:*/somepath"); + assert_eq!(host.port, Port::Pattern("*".into())); + assert_eq!(host.hostname.as_str(), "*.domain"); } #[test] - fn should_reject_when_validation_is_disabled() { - let valid = is_host_valid(Some("any"), &AllowHosts::Any); - assert!(valid); + fn should_allow_when_validation_is_disabled() { + assert!((AllowHosts::Any).verify("any").is_ok()); } #[test] fn should_reject_if_header_not_on_the_list() { - let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec![])); - assert!(!valid); + assert!((AllowHosts::Only(vec![].into())).verify("parity.io").is_err()); } #[test] fn should_accept_if_on_the_list() { - let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec!["parity.io".into()])); - assert!(valid); + assert!((AllowHosts::Only(vec!["parity.io".into()].into())).verify("parity.io").is_ok()); } #[test] fn should_accept_if_on_the_list_with_port() { - let valid = is_host_valid(Some("parity.io:443"), &AllowHosts::Only(vec!["parity.io:443".into()])); - assert!(valid); + assert!((AllowHosts::Only(vec!["parity.io:443".into()].into())).verify("parity.io:443").is_ok()); + assert!((AllowHosts::Only(vec!["parity.io".into()].into())).verify("parity.io:443").is_err()); } #[test] fn should_support_wildcards() { - let valid = is_host_valid(Some("parity.web3.site:8180"), &AllowHosts::Only(vec!["*.web3.site:*".into()])); - assert!(valid); + assert!((AllowHosts::Only(vec!["*.web3.site:*".into()].into())).verify("parity.web3.site:8180").is_ok()); } } diff --git a/http-server/src/access_control/matcher.rs b/core/src/server/access_control/matcher.rs similarity index 100% rename from http-server/src/access_control/matcher.rs rename to core/src/server/access_control/matcher.rs diff --git a/core/src/server/access_control/mod.rs b/core/src/server/access_control/mod.rs new file mode 100644 index 0000000000..5796a83455 --- /dev/null +++ b/core/src/server/access_control/mod.rs @@ -0,0 +1,172 @@ +//! Access control based on HTTP headers + +pub mod cors; +pub mod host; +mod matcher; + +pub use cors::{AllowHeaders, AllowOrigin, Origin}; +pub use host::{AllowHosts, Host}; + +use crate::Error; + +use self::cors::get_cors_allow_origin; + +/// Define access on control on HTTP layer. +#[derive(Clone, Debug)] +pub struct AccessControl { + allowed_hosts: AllowHosts, + allowed_origins: Option>, + allowed_headers: AllowHeaders, +} + +impl AccessControl { + /// Validate incoming request by host. + /// + /// `host` is the return value from the `host header` + pub fn verify_host(&self, host: &str) -> Result<(), Error> { + self.allowed_hosts.verify(host) + } + + /// Validate incoming request by origin. + /// + /// `host` is the return value from the `host header` + /// `origin` is the value from the `origin header`. + pub fn verify_origin(&self, origin: Option<&str>, host: &str) -> Result<(), Error> { + if let cors::AllowCors::Invalid = get_cors_allow_origin(origin, &self.allowed_origins, Some(host)) { + Err(Error::HttpHeaderRejected("origin", origin.unwrap_or("").into())) + } else { + Ok(()) + } + } + + /// Validate incoming request by CORS(`access-control-request-headers`). + /// + /// header_name: all keys of the header in the request + /// cors_request_headers: values of `access-control-request-headers` headers. + /// + pub fn verify_headers(&self, header_names: I, cors_request_headers: II) -> Result<(), Error> + where + T: AsRef, + I: Iterator, + II: Iterator, + { + let header = + cors::get_cors_allow_headers(header_names, cors_request_headers, &self.allowed_headers, |name| name); + + if let cors::AllowCors::Invalid = header { + Err(Error::HttpHeaderRejected( + "access-control-request-headers", + "".into(), + )) + } else { + Ok(()) + } + } + + /// Return the allowed headers we've set + pub fn allowed_headers(&self) -> &AllowHeaders { + &self.allowed_headers + } +} + +impl Default for AccessControl { + fn default() -> Self { + Self { allowed_hosts: AllowHosts::Any, allowed_origins: None, allowed_headers: AllowHeaders::Any } + } +} + +/// Convenience builder pattern +#[derive(Debug)] +pub struct AccessControlBuilder { + allowed_hosts: AllowHosts, + allowed_origins: Option>, + allowed_headers: AllowHeaders, +} + +impl Default for AccessControlBuilder { + fn default() -> Self { + Self { allowed_hosts: AllowHosts::Any, allowed_origins: None, allowed_headers: AllowHeaders::Any } + } +} + +impl AccessControlBuilder { + /// Create a new builder for `AccessControl`. + pub fn new() -> Self { + Self::default() + } + + /// Allow all hosts. + pub fn allow_all_hosts(mut self) -> Self { + self.allowed_hosts = AllowHosts::Any; + self + } + + /// Allow all origins. + pub fn allow_all_origins(mut self) -> Self { + self.allowed_origins = None; + self + } + + /// Allow all headers. + pub fn allow_all_headers(mut self) -> Self { + self.allowed_headers = AllowHeaders::Any; + self + } + + /// Configure allowed hosts. + /// + /// Default - allow all. + pub fn set_allowed_hosts(mut self, list: List) -> Result + where + List: IntoIterator, + H: Into, + { + let allowed_hosts: Vec<_> = list.into_iter().map(|s| Host::parse(&s.into())).map(Into::into).collect(); + if allowed_hosts.is_empty() { + return Err(Error::EmptyAllowList("Host")); + } + self.allowed_hosts = AllowHosts::Only(allowed_hosts); + Ok(self) + } + + /// Configure allowed origins. + /// + /// Default - allow all. + pub fn set_allowed_origins(mut self, list: List) -> Result + where + List: IntoIterator, + Origin: Into, + { + let allowed_origins: Vec = list.into_iter().map(Into::into).map(Into::into).collect(); + if allowed_origins.is_empty() { + return Err(Error::EmptyAllowList("Origin")); + } + self.allowed_origins = Some(allowed_origins); + Ok(self) + } + + /// Configure allowed CORS headers. + /// + /// Default - allow all. + pub fn set_allowed_headers(mut self, list: List) -> Result + where + List: IntoIterator, + Header: Into, + { + let allowed_headers: Vec = list.into_iter().map(Into::into).collect(); + if allowed_headers.is_empty() { + return Err(Error::EmptyAllowList("Header")); + } + self.allowed_headers = AllowHeaders::Only(allowed_headers); + Ok(self) + } + + /// Finalize the `AccessControl` settings. + pub fn build(self) -> AccessControl { + AccessControl { + allowed_hosts: self.allowed_hosts, + allowed_origins: self.allowed_origins, + allowed_headers: self.allowed_headers, + } + } +} diff --git a/core/src/server/mod.rs b/core/src/server/mod.rs index fe1a99277b..a1a00c0f91 100644 --- a/core/src/server/mod.rs +++ b/core/src/server/mod.rs @@ -26,6 +26,8 @@ //! Shared modules for the JSON-RPC servers. +/// Access control verification. +pub mod access_control; /// Helpers. pub mod helpers; /// Resource limiting. Create generic "resources" and configure their limits to ensure servers are not overloaded. diff --git a/examples/examples/cors_server.rs b/examples/examples/cors_server.rs index c2c5f52b6b..4374f8f175 100644 --- a/examples/examples/cors_server.rs +++ b/examples/examples/cors_server.rs @@ -26,7 +26,10 @@ use std::net::SocketAddr; -use jsonrpsee::http_server::{AccessControlBuilder, HttpServerBuilder, HttpServerHandle, RpcModule}; +use jsonrpsee::{ + core::server::access_control::AccessControlBuilder, + http_server::{HttpServerBuilder, HttpServerHandle, RpcModule}, +}; #[tokio::main] async fn main() -> anyhow::Result<()> { diff --git a/http-server/Cargo.toml b/http-server/Cargo.toml index 040edcf415..d7306442cf 100644 --- a/http-server/Cargo.toml +++ b/http-server/Cargo.toml @@ -15,15 +15,13 @@ futures-channel = "0.3.14" futures-util = { version = "0.3.14", default-features = false } jsonrpsee-types = { path = "../types", version = "0.13.1" } jsonrpsee-core = { path = "../core", version = "0.13.1", features = ["server", "http-helpers"] } -globset = "0.4" -lazy_static = "1.4" tracing = "0.1" serde_json = "1" tokio = { version = "1.16", features = ["rt-multi-thread", "macros"] } -unicase = "2.6.0" [dev-dependencies] env_logger = "0.9.0" +tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } jsonrpsee-test-utils = { path = "../test-utils" } jsonrpsee = { path = "../jsonrpsee", features = ["full"] } socket2 = "0.4" diff --git a/http-server/src/access_control/mod.rs b/http-server/src/access_control/mod.rs deleted file mode 100644 index b17d616443..0000000000 --- a/http-server/src/access_control/mod.rs +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -//! Access control based on HTTP headers - -pub(crate) mod cors; -pub(crate) mod hosts; -mod matcher; - -use cors::{AccessControlAllowHeaders, AccessControlAllowOrigin}; -use hosts::{AllowHosts, Host}; -use hyper::header; -use jsonrpsee_core::{http_helpers, Error}; - -/// Define access on control on HTTP layer. -#[derive(Clone, Debug)] -pub struct AccessControl { - allowed_hosts: AllowHosts, - allowed_origins: Option>, - allowed_headers: AccessControlAllowHeaders, - continue_on_invalid_cors: bool, -} - -impl AccessControl { - /// Validate incoming request by http HOST - pub fn deny_host(&self, request: &hyper::Request) -> bool { - !hosts::is_host_valid(http_helpers::read_header_value(request.headers(), "host"), &self.allowed_hosts) - } - - /// Validate incoming request by CORS origin - pub fn deny_cors_origin(&self, request: &hyper::Request) -> bool { - let header = cors::get_cors_allow_origin( - http_helpers::read_header_value(request.headers(), "origin"), - http_helpers::read_header_value(request.headers(), "host"), - &self.allowed_origins, - ) - .map(|origin| { - use self::cors::AccessControlAllowOrigin::*; - match origin { - Value(ref val) => { - header::HeaderValue::from_str(val).unwrap_or_else(|_| header::HeaderValue::from_static("null")) - } - Null => header::HeaderValue::from_static("null"), - Any => header::HeaderValue::from_static("*"), - } - }); - header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors - } - - /// Validate incoming request by CORS header - pub fn deny_cors_header(&self, request: &hyper::Request) -> bool { - let headers = request.headers().keys().map(|name| name.as_str()); - let requested_headers = http_helpers::read_header_values(request.headers(), "access-control-request-headers") - .filter_map(|val| val.to_str().ok()) - .flat_map(|val| val.split(", ")) - .flat_map(|val| val.split(',')); - - let header = cors::get_cors_allow_headers(headers, requested_headers, &self.allowed_headers, |name| { - header::HeaderValue::from_str(name).unwrap_or_else(|_| header::HeaderValue::from_static("unknown")) - }); - header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors - } - - /// Return the allowed headers we've set - pub(crate) fn allowed_headers(&self) -> &AccessControlAllowHeaders { - &self.allowed_headers - } -} - -impl Default for AccessControl { - fn default() -> Self { - Self { - allowed_hosts: AllowHosts::Any, - allowed_origins: None, - allowed_headers: AccessControlAllowHeaders::Any, - continue_on_invalid_cors: false, - } - } -} - -/// Convenience builder pattern -#[derive(Debug)] -pub struct AccessControlBuilder { - allowed_hosts: AllowHosts, - allowed_origins: Option>, - allowed_headers: AccessControlAllowHeaders, - continue_on_invalid_cors: bool, -} - -impl Default for AccessControlBuilder { - fn default() -> Self { - Self { - allowed_hosts: AllowHosts::Any, - allowed_origins: None, - allowed_headers: AccessControlAllowHeaders::Any, - continue_on_invalid_cors: false, - } - } -} - -impl AccessControlBuilder { - /// Create a new builder for `AccessControl`. - pub fn new() -> Self { - Self::default() - } - - /// Allow all hosts. - pub fn allow_all_hosts(mut self) -> Self { - self.allowed_hosts = AllowHosts::Any; - self - } - - /// Allow all origins. - pub fn allow_all_origins(mut self) -> Self { - self.allowed_headers = AccessControlAllowHeaders::Any; - self - } - - /// Allow all headers. - pub fn allow_all_headers(mut self) -> Self { - self.allowed_origins = None; - self - } - - /// Configure allowed hosts. - /// - /// Default - allow all. - pub fn set_allowed_hosts(mut self, list: List) -> Result - where - List: IntoIterator, - H: Into, - { - let allowed_hosts: Vec = list.into_iter().map(Into::into).collect(); - if allowed_hosts.is_empty() { - return Err(Error::EmptyAllowList("Host")); - } - self.allowed_hosts = AllowHosts::Only(allowed_hosts); - Ok(self) - } - - /// Configure allowed origins. - /// - /// Default - allow all. - pub fn set_allowed_origins(mut self, list: List) -> Result - where - List: IntoIterator, - Origin: Into, - { - let allowed_origins: Vec = list.into_iter().map(Into::into).collect(); - if allowed_origins.is_empty() { - return Err(Error::EmptyAllowList("Origin")); - } - self.allowed_origins = Some(allowed_origins); - Ok(self) - } - - /// Configure allowed CORS headers. - /// - /// Default - allow all. - pub fn set_allowed_headers(mut self, list: List) -> Result - where - List: IntoIterator, - Header: Into, - { - let allowed_headers: Vec = list.into_iter().map(Into::into).collect(); - if allowed_headers.is_empty() { - return Err(Error::EmptyAllowList("Header")); - } - self.allowed_headers = AccessControlAllowHeaders::Only(allowed_headers); - Ok(self) - } - - /// Enable or disable to continue with invalid CORS. - /// - /// Default: false. - pub fn continue_on_invalid_cors(mut self, continue_on_invalid_cors: bool) -> Self { - self.continue_on_invalid_cors = continue_on_invalid_cors; - self - } - - /// Build. - pub fn build(self) -> AccessControl { - AccessControl { - allowed_hosts: self.allowed_hosts, - allowed_origins: self.allowed_origins, - allowed_headers: self.allowed_headers, - continue_on_invalid_cors: self.continue_on_invalid_cors, - } - } -} diff --git a/http-server/src/lib.rs b/http-server/src/lib.rs index d96c19ed51..fab3147bb8 100644 --- a/http-server/src/lib.rs +++ b/http-server/src/lib.rs @@ -30,17 +30,12 @@ //! //! `jsonrpsee-http-server` is a [JSON RPC](https://www.jsonrpc.org/specification) HTTPS server library that's is built for `async/await`. -mod access_control; mod server; /// Common builders for RPC responses. pub mod response; -pub use access_control::{ - cors::{AccessControlAllowHeaders, AccessControlAllowOrigin}, - hosts::{AllowHosts, DomainsValidation, Host}, - AccessControl, AccessControlBuilder, -}; +pub use jsonrpsee_core::server::access_control::{AccessControl, AccessControlBuilder}; pub use jsonrpsee_core::server::rpc_module::RpcModule; pub use jsonrpsee_types as types; pub use server::{Builder as HttpServerBuilder, Server as HttpServer, ServerHandle as HttpServerHandle}; diff --git a/http-server/src/server.rs b/http-server/src/server.rs index e276b64706..f2594a167f 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -30,8 +30,8 @@ use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::response; use crate::response::{internal_error, malformed}; -use crate::{response, AccessControl}; use futures_channel::mpsc; use futures_util::{future::join_all, stream::StreamExt, FutureExt}; use hyper::header::{HeaderMap, HeaderValue}; @@ -41,6 +41,7 @@ use hyper::{Error as HyperError, Method}; use jsonrpsee_core::error::{Error, GenericTransportError}; use jsonrpsee_core::http_helpers::{self, read_body}; use jsonrpsee_core::middleware::Middleware; +use jsonrpsee_core::server::access_control::AccessControl; use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink}; use jsonrpsee_core::server::resource_limiting::Resources; use jsonrpsee_core::server::rpc_module::{MethodKind, Methods}; @@ -53,6 +54,7 @@ use tokio::net::{TcpListener, ToSocketAddrs}; /// Builder to create JSON-RPC HTTP server. #[derive(Debug)] pub struct Builder { + /// Access control based on HTTP headers. access_control: AccessControl, resources: Resources, max_request_body_size: u32, @@ -67,11 +69,11 @@ pub struct Builder { impl Default for Builder { fn default() -> Self { Self { + access_control: AccessControl::default(), max_request_body_size: TEN_MB_SIZE_BYTES, max_response_body_size: TEN_MB_SIZE_BYTES, batch_requests_supported: true, resources: Resources::default(), - access_control: AccessControl::default(), tokio_runtime: None, middleware: (), health_api: None, @@ -114,11 +116,11 @@ impl Builder { /// ``` pub fn set_middleware(self, middleware: T) -> Builder { Builder { + access_control: self.access_control, max_request_body_size: self.max_request_body_size, max_response_body_size: self.max_response_body_size, batch_requests_supported: self.batch_requests_supported, resources: self.resources, - access_control: self.access_control, tokio_runtime: self.tokio_runtime, middleware, health_api: self.health_api, @@ -215,9 +217,9 @@ impl Builder { local_addr: SocketAddr, ) -> Result, Error> { Ok(Server { + access_control: self.access_control, listener, local_addr: Some(local_addr), - access_control: self.access_control, max_request_body_size: self.max_request_body_size, max_response_body_size: self.max_response_body_size, batch_requests_supported: self.batch_requests_supported, @@ -358,9 +360,9 @@ pub struct Server { max_response_body_size: u32, /// Whether batch requests are supported by this server or not. batch_requests_supported: bool, - /// Access control + /// Access control. access_control: AccessControl, - /// Tracker for currently used resources on the server + /// Tracker for currently used resources on the server. resources: Resources, /// Custom tokio runtime to run the server on. tokio_runtime: Option, @@ -378,7 +380,7 @@ impl Server { pub fn start(mut self, methods: impl Into) -> Result { let max_request_body_size = self.max_request_body_size; let max_response_body_size = self.max_response_body_size; - let access_control = self.access_control; + let acl = self.access_control; let (tx, mut rx) = mpsc::channel(1); let listener = self.listener; let resources = self.resources; @@ -389,7 +391,7 @@ impl Server { let make_service = make_service_fn(move |_| { let methods = methods.clone(); - let access_control = access_control.clone(); + let acl = acl.clone(); let resources = resources.clone(); let middleware = middleware.clone(); let health_api = health_api.clone(); @@ -397,7 +399,7 @@ impl Server { async move { Ok::<_, HyperError>(service_fn(move |request| { let methods = methods.clone(); - let access_control = access_control.clone(); + let acl = acl.clone(); let resources = resources.clone(); let middleware = middleware.clone(); let health_api = health_api.clone(); @@ -405,8 +407,28 @@ impl Server { // Run some validation on the http request, then read the body and try to deserialize it into one of // two cases: a single RPC request or a batch of RPC requests. async move { - if let Err(e) = access_control_is_valid(&access_control, &request) { - return Ok::<_, HyperError>(e); + let keys = request.headers().keys().map(|k| k.as_str()); + let cors_request_headers = http_helpers::get_cors_request_headers(request.headers()); + + let host = match http_helpers::read_header_value(request.headers(), "host") { + Some(origin) => origin, + None => return Ok(malformed()), + }; + let maybe_origin = http_helpers::read_header_value(request.headers(), "origin"); + + if let Err(e) = acl.verify_host(host) { + tracing::warn!("Denied request: {:?}", e); + return Ok(response::host_not_allowed()); + } + + if let Err(e) = acl.verify_origin(maybe_origin, host) { + tracing::warn!("Denied request: {:?}", e); + return Ok(response::invalid_allow_origin()); + } + + if let Err(e) = acl.verify_headers(keys, cors_request_headers) { + tracing::warn!("Denied request: {:?}", e); + return Ok(response::invalid_allow_headers()); } // Only `POST` and `OPTIONS` methods are allowed. @@ -414,11 +436,12 @@ impl Server { // An OPTIONS request is a CORS preflight request. We've done our access check // above so we just need to tell the browser that the request is OK. Method::OPTIONS => { - let origin = match http_helpers::read_header_value(request.headers(), "origin") { + let origin = match maybe_origin { Some(origin) => origin, None => return Ok(malformed()), }; - let allowed_headers = access_control.allowed_headers().to_cors_header_value(); + + let allowed_headers = acl.allowed_headers().to_cors_header_value(); let allowed_header_bytes = allowed_headers.as_bytes(); let res = hyper::Response::builder() @@ -497,23 +520,6 @@ fn return_origin_if_different_from_host(headers: &HeaderMap) -> Option<&HeaderVa } } -// Checks to that access control of the received request is the same as configured. -fn access_control_is_valid( - access_control: &AccessControl, - request: &hyper::Request, -) -> Result<(), hyper::Response> { - if access_control.deny_host(request) { - return Err(response::host_not_allowed()); - } - if access_control.deny_cors_origin(request) { - return Err(response::invalid_allow_origin()); - } - if access_control.deny_cors_header(request) { - return Err(response::invalid_allow_headers()); - } - Ok(()) -} - /// Checks that content type of received request is valid for JSON-RPC. fn content_type_is_json(request: &hyper::Request) -> bool { is_json(request.headers().get("content-type")) diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 86760db67f..03c7bb41c4 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -29,7 +29,8 @@ use std::time::Duration; use futures::{SinkExt, StreamExt}; use jsonrpsee::core::error::SubscriptionClosed; -use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle}; +use jsonrpsee::core::server::access_control::{AccessControl, AccessControlBuilder}; +use jsonrpsee::http_server::{HttpServerBuilder, HttpServerHandle}; use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR}; use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle}; use jsonrpsee::RpcModule; @@ -218,7 +219,7 @@ pub async fn websocket_server_with_sleeping_subscription(tx: futures::channel::m } pub async fn http_server() -> (SocketAddr, HttpServerHandle) { - http_server_with_access_control(AccessControl::default()).await + http_server_with_access_control(AccessControlBuilder::default().build()).await } pub async fn http_server_with_access_control(acl: AccessControl) -> (SocketAddr, HttpServerHandle) { diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index c9dce9704f..13c7ac0d21 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -36,6 +36,7 @@ use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT use jsonrpsee::core::error::SubscriptionClosed; use jsonrpsee::core::{Error, JsonValue}; use jsonrpsee::http_client::HttpClientBuilder; +use jsonrpsee::http_server::AccessControlBuilder; use jsonrpsee::rpc_params; use jsonrpsee::types::error::ErrorObject; use jsonrpsee::ws_client::WsClientBuilder; @@ -769,3 +770,47 @@ async fn http_health_api_works() { let out = String::from_utf8(bytes.to_vec()).unwrap(); assert_eq!(out, "{\"jsonrpc\":\"2.0\",\"result\":\"im ok\",\"id\":0}"); } + +#[tokio::test] +async fn ws_host_filtering_wildcard_works() { + use jsonrpsee::ws_server::*; + + let acl = AccessControlBuilder::default() + .set_allowed_hosts(vec!["http://localhost:*", "http://127.0.0.1:*"]) + .unwrap() + .build(); + + let server = WsServerBuilder::default().set_access_control(acl).build("127.0.0.1:0").await.unwrap(); + let mut module = RpcModule::new(()); + let addr = server.local_addr().unwrap(); + module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); + + let _handle = server.start(module).unwrap(); + + let server_url = format!("ws://{}", addr); + let client = WsClientBuilder::default().build(&server_url).await.unwrap(); + + assert!(client.request::("say_hello", None).await.is_ok()); +} + +#[tokio::test] +async fn http_host_filtering_wildcard_works() { + use jsonrpsee::http_server::*; + + let acl = AccessControlBuilder::default() + .set_allowed_hosts(vec!["http://localhost:*", "http://127.0.0.1:*"]) + .unwrap() + .build(); + + let server = HttpServerBuilder::default().set_access_control(acl).build("127.0.0.1:0").await.unwrap(); + let mut module = RpcModule::new(()); + let addr = server.local_addr().unwrap(); + module.register_method("say_hello", |_, _| Ok("hello")).unwrap(); + + let _handle = server.start(module).unwrap(); + + let server_url = format!("http://{}", addr); + let client = HttpClientBuilder::default().build(&server_url).unwrap(); + + assert!(client.request::("say_hello", None).await.is_ok()); +} diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 4c7425dc45..b49a48305b 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -39,6 +39,7 @@ use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::StreamExt; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; use jsonrpsee_core::middleware::Middleware; +use jsonrpsee_core::server::access_control::AccessControl; use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, BoundedSubscriptions, MethodSink}; use jsonrpsee_core::server::resource_limiting::Resources; use jsonrpsee_core::server::rpc_module::{ConnState, ConnectionId, MethodKind, Methods}; @@ -244,8 +245,19 @@ where tracing::debug!("Accepting new connection: {}", conn_id); let key = { let req = server.receive_request().await?; - let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host)); - let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin); + + let host = std::str::from_utf8(req.headers().host) + .map_err(|_e| Error::HttpHeaderRejected("Host", "Invalid UTF-8".to_string()))?; + let origin = req.headers().origin.and_then(|h| { + let res = std::str::from_utf8(h).ok(); + if res.is_none() { + tracing::warn!("Origin header invalid UTF-8; treated as no Origin header"); + } + res + }); + + let host_check = cfg.access_control.verify_host(host); + let origin_check = cfg.access_control.verify_origin(origin, host); host_check.and(origin_check).map(|()| req.key()) }; @@ -255,11 +267,12 @@ where let accept = Response::Accept { key, protocol: None }; server.send_response(&accept).await?; } - Err(error) => { + Err(err) => { + tracing::warn!("Rejected connection: {:?}", err); let reject = Response::Reject { status_code: 403 }; server.send_response(&reject).await?; - return Err(error); + return Err(err); } } @@ -629,26 +642,6 @@ async fn background_task( result } -#[derive(Debug, Clone)] -enum AllowedValue { - Any, - OneOf(Box<[String]>), -} - -impl AllowedValue { - fn verify(&self, header: &str, value: Option<&[u8]>) -> Result<(), Error> { - if let (AllowedValue::OneOf(list), Some(value)) = (self, value) { - if !list.iter().any(|o| o.as_bytes() == value) { - let error = format!("{} denied: {}", header, String::from_utf8_lossy(value)); - tracing::warn!("{}", error); - return Err(Error::Custom(error)); - } - } - - Ok(()) - } -} - /// JSON-RPC Websocket server settings. #[derive(Debug, Clone)] struct Settings { @@ -660,10 +653,8 @@ struct Settings { max_connections: u64, /// Maximum number of subscriptions per connection. max_subscriptions_per_connection: u32, - /// Policy by which to accept or deny incoming requests based on the `Origin` header. - allowed_origins: AllowedValue, - /// Policy by which to accept or deny incoming requests based on the `Host` header. - allowed_hosts: AllowedValue, + /// Access control based on HTTP headers + access_control: AccessControl, /// Whether batch requests are supported by this server or not. batch_requests_supported: bool, /// Custom tokio runtime to run the server on. @@ -678,8 +669,7 @@ impl Default for Settings { max_subscriptions_per_connection: 1024, max_connections: MAX_CONNECTIONS, batch_requests_supported: true, - allowed_origins: AllowedValue::Any, - allowed_hosts: AllowedValue::Any, + access_control: AccessControl::default(), tokio_runtime: None, } } @@ -754,35 +744,6 @@ impl Builder { Ok(self) } - /// Set a list of allowed origins. During the handshake, the `Origin` header will be - /// checked against the list, connections without a matching origin will be denied. - /// Values should be hostnames with protocol. - /// - /// ```rust - /// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default(); - /// builder.set_allowed_origins(["https://example.com"]); - /// ``` - /// - /// By default allows any `Origin`. - /// - /// Will return an error if `list` is empty. Use [`allow_all_origins`](Builder::allow_all_origins) to restore the - /// default. - pub fn set_allowed_origins(mut self, list: List) -> Result - where - List: IntoIterator, - Origin: Into, - { - let list: Box<_> = list.into_iter().map(Into::into).collect(); - - if list.len() == 0 { - return Err(Error::EmptyAllowList("Origin")); - } - - self.settings.allowed_origins = AllowedValue::OneOf(list); - - Ok(self) - } - /// Add a middleware to the builder [`Middleware`](../jsonrpsee_core/middleware/trait.Middleware.html). /// /// ``` @@ -812,49 +773,6 @@ impl Builder { Builder { settings: self.settings, resources: self.resources, middleware, id_provider: self.id_provider } } - /// Restores the default behavior of allowing connections with `Origin` header - /// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins). - pub fn allow_all_origins(mut self) -> Self { - self.settings.allowed_origins = AllowedValue::Any; - self - } - - /// Set a list of allowed hosts. During the handshake, the `Host` header will be - /// checked against the list. Connections without a matching host will be denied. - /// Values should be hostnames without protocol. - /// - /// ```rust - /// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default(); - /// builder.set_allowed_hosts(["example.com"]); - /// ``` - /// - /// By default allows any `Host`. - /// - /// Will return an error if `list` is empty. Use [`allow_all_hosts`](Builder::allow_all_hosts) to restore the - /// default. - pub fn set_allowed_hosts(mut self, list: List) -> Result - where - List: IntoIterator, - Host: Into, - { - let list: Box<_> = list.into_iter().map(Into::into).collect(); - - if list.len() == 0 { - return Err(Error::EmptyAllowList("Host")); - } - - self.settings.allowed_hosts = AllowedValue::OneOf(list); - - Ok(self) - } - - /// Restores the default behavior of allowing connections with `Host` header - /// containing any value. This will undo any list set by [`set_allowed_hosts`](Builder::set_allowed_hosts). - pub fn allow_all_hosts(mut self) -> Self { - self.settings.allowed_hosts = AllowedValue::Any; - self - } - /// Configure a custom [`tokio::runtime::Handle`] to run the server on. /// /// Default: [`tokio::spawn`] @@ -888,6 +806,12 @@ impl Builder { self } + /// Sets access control settings. + pub fn set_access_control(mut self, acl: AccessControl) -> Self { + self.settings.access_control = acl; + self + } + /// Finalize the configuration of the server. Consumes the [`Builder`]. /// /// ```rust