From 6292d3a3ae55bb1c54927196e0fa1c4f15c5e0d8 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 23 May 2023 20:36:59 -0700 Subject: [PATCH] Consolidate serialization of HelperIdentity --- src/helpers/mod.rs | 19 +++++++++++++++---- src/net/client/mod.rs | 5 ++++- src/net/server/mod.rs | 16 ++++------------ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index 837b97121..c4ab2ec1f 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -58,13 +58,24 @@ pub const MESSAGE_PAYLOAD_SIZE_BYTES: usize = MessagePayloadArrayLen::USIZE; #[derive(Copy, Clone, Eq, PartialEq, Hash)] #[cfg_attr( feature = "enable-serde", - derive(serde::Serialize, serde::Deserialize), - serde(transparent) + derive(serde::Deserialize), + serde(try_from = "usize") )] pub struct HelperIdentity { id: u8, } +// Serialize as `serde(transparent)` would. Don't see how to enable that +// for only one of (de)serialization. +impl serde::Serialize for HelperIdentity { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + self.id.serialize(serializer) + } +} + impl TryFrom for HelperIdentity { type Error = String; @@ -99,8 +110,8 @@ impl Debug for HelperIdentity { #[cfg(feature = "web-app")] impl From for hyper::header::HeaderValue { fn from(id: HelperIdentity) -> Self { - // does not implement `From` - hyper::header::HeaderValue::from(u16::from(id.id)) + // panic if serializing an integer fails, or is not ASCII + hyper::header::HeaderValue::try_from(serde_json::to_string(&id).unwrap()).unwrap() } } diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 8ec210ac9..95d05b830 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -95,7 +95,10 @@ impl MpcHelperClient { error!("certificate identity ignored for HTTP client"); None } - ClientIdentity::Helper(id) => Some((HTTP_CLIENT_ID_HEADER.clone(), id.into())), + ClientIdentity::Helper(id) => Some(( + HTTP_CLIENT_ID_HEADER.clone(), + id.try_into().expect("integer not ascii?"), + )), ClientIdentity::None => None, }; (HttpsConnector::new(), auth_header) diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index eea169c63..38b2f4b3d 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -32,7 +32,6 @@ use std::{ io, net::{Ipv4Addr, SocketAddr, TcpListener}, ops::Deref, - str::FromStr, task::{Context, Poll}, }; use tokio_rustls::{ @@ -447,7 +446,7 @@ pub static HTTP_CLIENT_ID_HEADER: HeaderName = /// Since this allows a client to claim any identity, it is completely /// insecure. It must only be used in contexts where that is acceptable. #[derive(Clone)] -pub(super) struct SetClientIdentityFromHeader { +struct SetClientIdentityFromHeader { inner: S, } @@ -470,16 +469,9 @@ impl, Response = Response>> Service> } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(header_value) = req.headers().get(HTTP_CLIENT_ID_HEADER.clone()) { - let id_result = header_value - .to_str() - .map_err(Into::into) - .and_then(|value_str| usize::from_str(value_str).map_err(Into::into)) - .and_then(|value_int| { - HelperIdentity::try_from(value_int).map_err(|e| { - Error::InvalidHeader(format!("{HTTP_CLIENT_ID_HEADER}: {e:?}").into()) - }) - }); + if let Some(header_value) = req.headers().get(&HTTP_CLIENT_ID_HEADER) { + let id_result = serde_json::from_slice(header_value.as_ref()) + .map_err(|e| Error::InvalidHeader(format!("{HTTP_CLIENT_ID_HEADER}: {e}").into())); match id_result { Ok(id) => req.extensions_mut().insert(ClientIdentity(id)), Err(err) => return ready(Ok(err.into_response())).right_future(),