Skip to content

Commit

Permalink
feat(middleware): add HostFilterLater::disable (#1213)
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 authored Oct 11, 2023
1 parent 684c946 commit bccf49c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
34 changes: 30 additions & 4 deletions server/src/middleware/host_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use tower::{Layer, Service};

/// Middleware to enable host filtering.
#[derive(Debug)]
pub struct HostFilterLayer(Arc<WhitelistedHosts>);
pub struct HostFilterLayer(Option<Arc<WhitelistedHosts>>);

impl HostFilterLayer {
/// Enables host filtering and allow only the specified hosts.
Expand All @@ -49,7 +49,33 @@ impl HostFilterLayer {
U: TryInto<Authority, Error = AuthorityError>,
{
let allow_only: Result<Vec<_>, _> = allow_only.into_iter().map(|a| a.try_into()).collect();
Ok(Self(Arc::new(WhitelistedHosts::from(allow_only?))))
Ok(Self(Some(Arc::new(WhitelistedHosts::from(allow_only?)))))
}

/// Convenience method to disable host filtering but less efficient
/// than to not enable the middleware at all.
///
/// Because is the `tower middleware` returns a different type
/// depending on which Layers are configured it and may not compile
/// in some contexts.
///
/// For example the following won't compile:
///
/// ```ignore
/// use jsonrpsee_server::middleware::{ProxyGetRequestLayer, HostFilterLayer};
///
/// let host_filter = false;
///
/// let middleware = if host_filter {
/// tower::ServiceBuilder::new()
/// .layer(HostFilterLayer::new(["example.com"]).unwrap())
/// .layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap())
/// } else {
/// tower::ServiceBuilder::new()
/// };
/// ```
pub fn disable() -> Self {
Self(None)
}
}

Expand All @@ -65,7 +91,7 @@ impl<S> Layer<S> for HostFilterLayer {
#[derive(Debug)]
pub struct HostFilter<S> {
inner: S,
filter: Arc<WhitelistedHosts>,
filter: Option<Arc<WhitelistedHosts>>,
}

impl<S> Service<Request<Body>> for HostFilter<S>
Expand All @@ -88,7 +114,7 @@ where
return async { Ok(http::response::malformed()) }.boxed();
};

if self.filter.recognize(&authority) {
if self.filter.as_ref().map_or(true, |f| f.recognize(&authority)) {
Box::pin(self.inner.call(request).map_err(Into::into))
} else {
tracing::debug!("Denied request: {:?}", request);
Expand Down
29 changes: 29 additions & 0 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,35 @@ async fn deny_invalid_host() {
}
}

#[tokio::test]
async fn disable_host_filter_works() {
use jsonrpsee::server::*;

init_logger();

let middleware = tower::ServiceBuilder::new().layer(HostFilterLayer::disable());

let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0").await.unwrap();
let mut module = RpcModule::new(());
let addr = server.local_addr().unwrap();
module.register_method("say_hello", |_, _| "hello").unwrap();

let _handle = server.start(module);

// HTTP
{
let server_url = format!("http://{}", addr);
let client = HttpClientBuilder::default().build(&server_url).unwrap();
assert!(client.request::<String, _>("say_hello", rpc_params![]).await.is_ok());
}

// WebSocket
{
let server_url = format!("ws://{}", addr);
assert!(WsClientBuilder::default().build(&server_url).await.is_ok());
}
}

#[tokio::test]
async fn subscription_option_err_is_not_sent() {
init_logger();
Expand Down

0 comments on commit bccf49c

Please sign in to comment.