Skip to content

Commit

Permalink
Move CSRF plugin code
Browse files Browse the repository at this point in the history
Migrate to new test harness.
Fix add missing tests.
  • Loading branch information
bryn committed Feb 11, 2025
1 parent 7041f1b commit 7056a1d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 67 deletions.
2 changes: 2 additions & 0 deletions apollo-router/src/plugins/csrf/fixtures/default.router.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
csrf:
unsafe_disabled: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
csrf:
required_headers:
- X-MY-CSRF-Token
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
csrf:
unsafe_disabled: true
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub(crate) struct CSRFConfig {
/// set unsafe_disabled = true to disable the plugin behavior
/// Note that setting this to true is deemed unsafe.
/// See <https://developer.mozilla.org/en-US/docs/Glossary/CSRF>.
/// TODO rename this to enabled. This is in line with the other plugins and will be less confusing.
unsafe_disabled: bool,
/// Override the headers to check for by setting
/// custom_headers
Expand Down Expand Up @@ -209,7 +210,6 @@ register_plugin!("apollo", "csrf", Csrf);

#[cfg(test)]
mod csrf_tests {
use crate::plugin::PluginInit;
#[tokio::test]
async fn plugin_registered() {
crate::plugin::plugins()
Expand All @@ -228,126 +228,160 @@ mod csrf_tests {
}

use http::header::CONTENT_TYPE;
use http_body_util::BodyExt;
use mime::APPLICATION_JSON;
use serde_json_bytes::json;
use tower::ServiceExt;

use super::*;
use crate::plugin::test::MockRouterService;
use crate::graphql;
use crate::plugins::test::PluginTestHarness;

#[tokio::test]
async fn it_lets_preflighted_request_pass_through() {
let config = CSRFConfig::default();
let with_preflight_content_type = router::Request::fake_builder()
.header(CONTENT_TYPE, APPLICATION_JSON.essence_str())
.build()
.unwrap();
assert_accepted(config.clone(), with_preflight_content_type).await;
assert_accepted(
include_str!("fixtures/default.router.yaml"),
with_preflight_content_type,
)
.await;

let with_preflight_header = router::Request::fake_builder()
.header("apollo-require-preflight", "this-is-a-test")
.build()
.unwrap();
assert_accepted(config, with_preflight_header).await;
assert_accepted(
include_str!("fixtures/default.router.yaml"),
with_preflight_header,
)
.await;
}

#[tokio::test]
async fn it_rejects_non_preflighted_headers_request() {
let config = CSRFConfig::default();
let mut non_preflighted_request = router::Request::fake_builder().build().unwrap();
// fake_builder defaults to `Content-Type: application/json`,
// specifically to avoid the case we’re testing here.
non_preflighted_request
.router_request
.headers_mut()
.remove("content-type");
assert_rejected(config, non_preflighted_request).await
assert_rejected(
include_str!("fixtures/default.router.yaml"),
non_preflighted_request,
)
.await
}

#[tokio::test]
async fn it_rejects_non_preflighted_content_type_request() {
let config = CSRFConfig::default();
let non_preflighted_request = router::Request::fake_builder()
.header(CONTENT_TYPE, "text/plain")
.build()
.unwrap();
assert_rejected(config.clone(), non_preflighted_request).await;
assert_rejected(
include_str!("fixtures/default.router.yaml"),
non_preflighted_request,
)
.await;

let non_preflighted_request = router::Request::fake_builder()
.header(CONTENT_TYPE, "text/plain; charset=utf8")
.build()
.unwrap();
assert_rejected(config, non_preflighted_request).await;
assert_rejected(
include_str!("fixtures/default.router.yaml"),
non_preflighted_request,
)
.await;
}

#[tokio::test]
async fn it_accepts_non_preflighted_headers_request_when_plugin_is_disabled() {
let config = CSRFConfig {
unsafe_disabled: true,
..Default::default()
};
let non_preflighted_request = router::Request::fake_builder().build().unwrap();
assert_accepted(config, non_preflighted_request).await
assert_accepted(
include_str!("fixtures/unsafe_disabled.router.yaml"),
non_preflighted_request,
)
.await
}

async fn assert_accepted(config: CSRFConfig, request: router::Request) {
let mut mock_service = MockRouterService::new();
mock_service.expect_call().times(1).returning(move |_| {
Ok(router::Response::fake_builder()
.data(json!({ "test": 1234_u32 }))
.build()
.unwrap())
});
#[tokio::test]
async fn it_rejects_non_preflighted_headers_request_when_required_headers_are_not_present() {
let non_preflighted_request = router::Request::fake_builder().build().unwrap();
assert_rejected(
include_str!("fixtures/required_headers.router.yaml"),
non_preflighted_request,
)
.await
}

let service_stack = Csrf::new(PluginInit::fake_new(config, Default::default()))
.await
.unwrap()
.router_service(mock_service.boxed());
let res = service_stack
.oneshot(request)
// Check that when the headers are present, the request is accepted
#[tokio::test]
async fn it_accepts_non_preflighted_headers_request_when_required_headers_are_present() {
let non_preflighted_request = router::Request::fake_builder()
.header("X-MY-CSRF-Token", "this-is-a-test")
.build()
.unwrap();
assert_accepted(
include_str!("fixtures/required_headers.router.yaml"),
non_preflighted_request,
)
.await
}

async fn assert_accepted(config: &'static str, request: router::Request) {
let plugin = PluginTestHarness::<Csrf>::builder()
.config(config)
.build()
.await;
let router_service =
plugin.router_service(|_r| async { router::Response::fake_builder().build() });
let mut resp = router_service
.call(request)
.await
.unwrap()
.next_response()
.expect("expected response");

let body = resp
.response
.body_mut()
.collect()
.await
.unwrap()
.unwrap();
.expect("expected body");

let json: serde_json::Value = serde_json::from_slice(&res).unwrap();
insta::assert_json_snapshot!(json, @r#"
{
"data": {
"test": 1234
}
}
"#);
let response: graphql::Response = serde_json::from_slice(&body.to_bytes()).unwrap();
assert_eq!(response.errors.len(), 0);
}

async fn assert_rejected(config: CSRFConfig, request: router::Request) {
let service_stack = Csrf::new(PluginInit::fake_new(config, Default::default()))
.await
.unwrap()
.router_service(MockRouterService::new().boxed());
let res = service_stack
.oneshot(request)
async fn assert_rejected(config: &'static str, request: router::Request) {
let plugin = PluginTestHarness::<Csrf>::builder()
.config(config)
.build()
.await;
let router_service =
plugin.router_service(|_r| async { router::Response::fake_builder().build() });
let mut resp = router_service
.call(request)
.await
.unwrap()
.next_response()
.expect("expected response");

let body = resp
.response
.body_mut()
.collect()
.await
.unwrap()
.unwrap();
.expect("expected body");

let json: serde_json::Value = serde_json::from_slice(&res).unwrap();
insta::assert_json_snapshot!(json, @r#"
{
"errors": [
{
"message": "This operation has been blocked as a potential Cross-Site Request Forgery (CSRF). Please either specify a 'content-type' header (with a mime-type that is not one of application/x-www-form-urlencoded, multipart/form-data, text/plain) or provide one of the following headers: x-apollo-operation-name, apollo-require-preflight",
"extensions": {
"code": "CSRF_ERROR"
}
}
]
}
"#);
let response: graphql::Response = serde_json::from_slice(&body.to_bytes()).unwrap();
assert_eq!(response.errors.len(), 1);
assert_eq!(
response.errors[0]
.extensions
.get("code")
.expect("error code")
.as_str(),
Some("CSRF_ERROR")
);
}
}

0 comments on commit 7056a1d

Please sign in to comment.