Skip to content

Commit

Permalink
Implement Serde (De)serialization for Cors struct (lawliet89#7)
Browse files Browse the repository at this point in the history
* Add serde support for Method

* "Turn on" serde

* Add Serde support for UniCase

* Fix merge error

* Add default tests
  • Loading branch information
lawliet89 authored and v1olen committed Mar 18, 2021
1 parent cc8dc20 commit 582872f
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 32 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ log = "0.3"
rocket = "0.3"
serde = "1.0"
serde_derive = "1.0"
unicase="1.4"
unicase = "2.0"
unicase_serde = "0.1.0"
url = "1.5.1"
url_serde = "0.2.0"

Expand All @@ -31,3 +32,4 @@ version_check = "0.1"
hyper = "0.10"
rocket_codegen = "0.3"
serde_json = "1.0"
serde_test = "1.0"
65 changes: 56 additions & 9 deletions src/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,56 @@ use std::ops::Deref;
use std::str::FromStr;

use rocket::{self, Outcome};
use rocket::http::{Method, Status};
use rocket::http::Status;
use rocket::request::{self, FromRequest};
use unicase::UniCase;
use unicase_serde;
use url;
use url_serde;

pub(crate) type HeaderFieldName = UniCase<String>;
pub(crate) type HeaderFieldNamesSet = HashSet<HeaderFieldName>;
/// A case insensitive header name
#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug, Hash)]
pub struct HeaderFieldName(
#[serde(with = "unicase_serde::unicase")]
UniCase<String>
);

impl Deref for HeaderFieldName {
type Target = String;

fn deref(&self) -> &Self::Target {
self.0.deref()
}
}

impl fmt::Display for HeaderFieldName {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

impl<'a> From<&'a str> for HeaderFieldName {
fn from(s: &'a str) -> Self {
HeaderFieldName(From::from(s))
}
}

impl<'a> From<String> for HeaderFieldName {
fn from(s: String) -> Self {
HeaderFieldName(From::from(s))
}
}

impl FromStr for HeaderFieldName {
type Err = <String as FromStr>::Err;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(HeaderFieldName(FromStr::from_str(s)?))
}
}

/// A set of case insensitive header names
pub type HeaderFieldNamesSet = HashSet<HeaderFieldName>;

/// A wrapped `url::Url` to allow for deserialization
#[derive(Eq, PartialEq, Clone, Hash, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -72,13 +114,13 @@ pub type Origin = Url;
/// You can use this as a rocket [Request Guard](https://rocket.rs/guide/requests/#request-guards)
/// to ensure that the header is passed in correctly.
#[derive(Debug)]
pub struct AccessControlRequestMethod(pub Method);
pub struct AccessControlRequestMethod(pub ::Method);

impl FromStr for AccessControlRequestMethod {
type Err = rocket::Error;

fn from_str(method: &str) -> Result<Self, Self::Err> {
Ok(AccessControlRequestMethod(Method::from_str(method)?))
Ok(AccessControlRequestMethod(::Method::from_str(method)?))
}
}

Expand Down Expand Up @@ -117,7 +159,7 @@ impl FromStr for AccessControlRequestHeaders {

let set: HeaderFieldNamesSet = headers
.split(',')
.map(|header| UniCase(header.trim().to_string()))
.map(|header| From::from(header.trim().to_string()))
.collect();
Ok(AccessControlRequestHeaders(set))
}
Expand Down Expand Up @@ -149,7 +191,6 @@ mod tests {
use hyper;
use rocket;
use rocket::local::Client;
use rocket::http::Method;

use super::*;

Expand Down Expand Up @@ -194,11 +235,17 @@ mod tests {
fn request_method_conversion() {
let method = "POST";
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Post));
assert_matches!(
parsed_method,
AccessControlRequestMethod(::Method(rocket::http::Method::Post))
);

let method = "options";
let parsed_method = not_err!(AccessControlRequestMethod::from_str(method));
assert_matches!(parsed_method, AccessControlRequestMethod(Method::Options));
assert_matches!(
parsed_method,
AccessControlRequestMethod(::Method(rocket::http::Method::Options))
);

let method = "INVALID";
let _ = is_err!(AccessControlRequestMethod::from_str(method));
Expand Down
Loading

0 comments on commit 582872f

Please sign in to comment.