diff --git a/Cargo.toml b/Cargo.toml index b40c2b52c..dcd1e41fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,12 +17,12 @@ edition = "2018" all-features = true [dependencies] -async-compression = { version = "0.3.1", features = ["brotli", "deflate", "gzip", "stream"], optional = true } -bytes = "0.5" +async-compression = { git = "https://github.com/aknuds1/async-compression", rev = "e4d903b8ff9972f056f714c61bd9d0f3321a4463", features = ["brotli", "deflate", "gzip", "stream"], optional = true } +bytes = "0.6" futures = { version = "0.3", default-features = false, features = ["alloc"] } headers = "0.3" http = "0.2" -hyper = { version = "0.13", features = ["stream"] } +hyper = { git = "https://github.com/hyperium/hyper.git", rev = "1ba2a141a6f8736446ff4a0111df347c0dc66f6c", features = ["stream", "server", "http1", "http2", "tcp", "client"] } log = "0.4" mime = "0.3" mime_guess = "2.0.0" @@ -31,15 +31,15 @@ scoped-tls = "1.0" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.7" -tokio = { version = "0.2", features = ["fs", "stream", "sync", "time"] } +tokio = { version = "0.3", features = ["fs", "stream", "sync", "time"] } tracing = { version = "0.1", default-features = false, features = ["log", "std"] } tracing-futures = { version = "0.2", default-features = false, features = ["std-future"] } tower-service = "0.3" # tls is enabled by default, we don't want that yet -tokio-tungstenite = { version = "0.11", default-features = false, optional = true } +tokio-tungstenite = { git = "https://github.com/snapview/tokio-tungstenite.git", rev = "71a5b72123db32b318d48964948fc76c943f1548", default-features = false, optional = true } percent-encoding = "2.1" pin-project = "1.0" -tokio-rustls = { version = "0.14", optional = true } +tokio-rustls = { version = "0.21", optional = true } [dev-dependencies] pretty_env_logger = "0.4" @@ -47,7 +47,7 @@ tracing-subscriber = "0.2.7" tracing-log = "0.1" serde_derive = "1.0" handlebars = "3.0.0" -tokio = { version = "0.2", features = ["macros"] } +tokio = { version = "0.3", features = ["macros", "rt-multi-thread"] } listenfd = "0.3" [features] @@ -78,7 +78,6 @@ required-features = ["compression"] [[example]] name = "unix_socket" -required-features = ["tokio/uds"] [[example]] name = "websockets" diff --git a/examples/futures.rs b/examples/futures.rs index 013428093..43bf2f6ef 100644 --- a/examples/futures.rs +++ b/examples/futures.rs @@ -16,7 +16,7 @@ async fn main() { } async fn sleepy(Seconds(seconds): Seconds) -> Result { - tokio::time::delay_for(Duration::from_secs(seconds)).await; + tokio::time::sleep(Duration::from_secs(seconds)).await; Ok(format!("I waited {} seconds!", seconds)) } diff --git a/examples/unix_socket.rs b/examples/unix_socket.rs index 951a28782..56010bb81 100644 --- a/examples/unix_socket.rs +++ b/examples/unix_socket.rs @@ -6,9 +6,8 @@ use tokio::net::UnixListener; async fn main() { pretty_env_logger::init(); - let mut listener = UnixListener::bind("/tmp/warp.sock").unwrap(); - let incoming = listener.incoming(); + let listener = UnixListener::bind("/tmp/warp.sock").unwrap(); warp::serve(warp::fs::dir("examples/dir")) - .run_incoming(incoming) + .run_incoming(listener) .await; } diff --git a/src/filters/body.rs b/src/filters/body.rs index 0def35a4e..4457a52d8 100644 --- a/src/filters/body.rs +++ b/src/filters/body.rs @@ -7,7 +7,7 @@ use std::fmt; use std::pin::Pin; use std::task::{Context, Poll}; -use bytes::{buf::BufExt, Buf, Bytes}; +use bytes::{Buf, Bytes}; use futures::{future, ready, Stream, TryFutureExt}; use headers::ContentLength; use http::header::CONTENT_TYPE; diff --git a/src/filters/fs.rs b/src/filters/fs.rs index 1f04f6bb5..3a833247a 100644 --- a/src/filters/fs.rs +++ b/src/filters/fs.rs @@ -8,9 +8,10 @@ use std::io; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -use std::task::Poll; +use std::task::{Poll, Context}; +use std::mem::MaybeUninit; -use bytes::{Bytes, BytesMut}; +use bytes::{Bytes, BytesMut, BufMut}; use futures::future::Either; use futures::{future, ready, stream, FutureExt, Stream, StreamExt, TryFutureExt}; use headers::{ @@ -22,7 +23,7 @@ use hyper::Body; use mime_guess; use percent_encoding::percent_decode_str; use tokio::fs::File as TkFile; -use tokio::io::AsyncRead; +use tokio::io::{AsyncSeekExt, AsyncRead, ReadBuf}; use crate::filter::{Filter, FilterClone, One}; use crate::reject::{self, Rejection}; @@ -419,7 +420,7 @@ fn file_stream( } reserve_at_least(&mut buf, buf_size); - let n = match ready!(Pin::new(&mut f).poll_read_buf(cx, &mut buf)) { + let n = match ready!(poll_read_buf(Pin::new(&mut f), cx, &mut buf)) { Ok(n) => n as u64, Err(err) => { tracing::debug!("file read error: {}", err); @@ -524,3 +525,33 @@ mod tests { assert_eq!(buf.capacity(), cap); } } + +fn poll_read_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.bytes_mut(); + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(io.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) +} diff --git a/src/filters/sse.rs b/src/filters/sse.rs index 8c4eda58f..9d5849230 100644 --- a/src/filters/sse.rs +++ b/src/filters/sse.rs @@ -54,7 +54,7 @@ use hyper::Body; use pin_project::pin_project; use serde::Serialize; use serde_json; -use tokio::time::{self, Delay}; +use tokio::time::{self, Sleep}; use self::sealed::{ BoxedServerSentEvent, EitherServerSentEvent, SseError, SseField, SseFormat, SseWrapper, @@ -467,7 +467,7 @@ impl KeepAlive { S::Ok: ServerSentEvent + Send, S::Error: StdError + Send + Sync + 'static, { - let alive_timer = time::delay_for(self.max_interval); + let alive_timer = time::sleep(self.max_interval); SseKeepAlive { event_stream, comment_text: self.comment_text, @@ -484,7 +484,7 @@ struct SseKeepAlive { event_stream: S, comment_text: Cow<'static, str>, max_interval: Duration, - alive_timer: Delay, + alive_timer: Sleep, } #[doc(hidden)] @@ -505,7 +505,7 @@ where let max_interval = keep_interval .into() .unwrap_or_else(|| Duration::from_secs(15)); - let alive_timer = time::delay_for(max_interval); + let alive_timer = time::sleep(max_interval); SseKeepAlive { event_stream, comment_text: Cow::Borrowed(""), diff --git a/src/filters/ws.rs b/src/filters/ws.rs index bbd1836d2..ff752c7f6 100644 --- a/src/filters/ws.rs +++ b/src/filters/ws.rs @@ -134,10 +134,7 @@ where fn into_response(self) -> Response { let on_upgrade = self.on_upgrade; let config = self.ws.config; - let fut = self - .ws - .body - .on_upgrade() + let fut = hyper::upgrade::on(Response::new(self.ws.body)) .and_then(move |upgraded| { tracing::trace!("websocket upgrade complete"); WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok) diff --git a/src/server.rs b/src/server.rs index 14024e5b6..6ff8d9580 100644 --- a/src/server.rs +++ b/src/server.rs @@ -416,26 +416,76 @@ where // TLS config methods /// Specify the file path to read the private key. + /// + /// *This function requires the `"tls"` feature.* pub fn key_path(self, path: impl AsRef) -> Self { self.with_tls(|tls| tls.key_path(path)) } /// Specify the file path to read the certificate. + /// + /// *This function requires the `"tls"` feature.* pub fn cert_path(self, path: impl AsRef) -> Self { self.with_tls(|tls| tls.cert_path(path)) } + /// Specify the file path to read the trust anchor for optional client authentication. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_optional_path(self, path: impl AsRef) -> Self { + self.with_tls(|tls| tls.client_auth_optional_path(path)) + } + + /// Specify the file path to read the trust anchor for required client authentication. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_required_path(self, path: impl AsRef) -> Self { + self.with_tls(|tls| tls.client_auth_required_path(path)) + } + /// Specify the in-memory contents of the private key. + /// + /// *This function requires the `"tls"` feature.* pub fn key(self, key: impl AsRef<[u8]>) -> Self { self.with_tls(|tls| tls.key(key.as_ref())) } /// Specify the in-memory contents of the certificate. + /// + /// *This function requires the `"tls"` feature.* pub fn cert(self, cert: impl AsRef<[u8]>) -> Self { self.with_tls(|tls| tls.cert(cert.as_ref())) } + /// Specify the in-memory contents of the trust anchor for optional client authentication. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_optional(self, trust_anchor: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.client_auth_optional(trust_anchor.as_ref())) + } + + /// Specify the in-memory contents of the trust anchor for required client authentication. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + /// + /// *This function requires the `"tls"` feature.* + pub fn client_auth_required(self, trust_anchor: impl AsRef<[u8]>) -> Self { + self.with_tls(|tls| tls.client_auth_required(trust_anchor.as_ref())) + } + /// Specify the DER-encoded OCSP response. + /// + /// *This function requires the `"tls"` feature.* pub fn ocsp_resp(self, resp: impl AsRef<[u8]>) -> Self { self.with_tls(|tls| tls.ocsp_resp(resp.as_ref())) } diff --git a/src/test.rs b/src/test.rs index 50fc0e388..f2ff104de 100644 --- a/src/test.rs +++ b/src/test.rs @@ -515,7 +515,7 @@ impl WsBuilder { let upgrade = ::hyper::Client::builder() .build(AddrConnect(addr)) .request(req) - .and_then(|res| res.into_body().on_upgrade()); + .and_then(|res| hyper::upgrade::on(res)); let upgraded = match upgrade.await { Ok(up) => { diff --git a/src/tls.rs b/src/tls.rs index c68887c68..44cb7c13c 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -6,14 +6,17 @@ use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use futures::ready; use hyper::server::accept::Accept; use hyper::server::conn::{AddrIncoming, AddrStream}; use crate::transport::Transport; -use tokio_rustls::rustls::{NoClientAuth, ServerConfig, TLSError}; +use tokio_rustls::rustls::{ + AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth, + RootCertStore, ServerConfig, TLSError, +}; /// Represents errors that can occur building the TlsConfig #[derive(Debug)] @@ -46,10 +49,21 @@ impl std::fmt::Display for TlsConfigError { impl std::error::Error for TlsConfigError {} +/// Tls client authentication configuration. +pub(crate) enum TlsClientAuth { + /// No client auth. + Off, + /// Allow any anonymous or authenticated client. + Optional(Box), + /// Allow any authenticated client. + Required(Box), +} + /// Builder to set the configuration for the Tls server. pub(crate) struct TlsConfigBuilder { cert: Box, key: Box, + client_auth: TlsClientAuth, ocsp_resp: Vec, } @@ -65,6 +79,7 @@ impl TlsConfigBuilder { TlsConfigBuilder { key: Box::new(io::empty()), cert: Box::new(io::empty()), + client_auth: TlsClientAuth::Off, ocsp_resp: Vec::new(), } } @@ -99,6 +114,52 @@ impl TlsConfigBuilder { self } + /// Sets the trust anchor for optional Tls client authentication via file path. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_optional_path(mut self, path: impl AsRef) -> Self { + let file = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self.client_auth = TlsClientAuth::Optional(file); + self + } + + /// Sets the trust anchor for optional Tls client authentication via bytes slice. + /// + /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any + /// of the `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self { + let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); + self.client_auth = TlsClientAuth::Optional(cursor); + self + } + + /// Sets the trust anchor for required Tls client authentication via file path. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_required_path(mut self, path: impl AsRef) -> Self { + let file = Box::new(LazyFile { + path: path.as_ref().into(), + file: None, + }); + self.client_auth = TlsClientAuth::Required(file); + self + } + + /// Sets the trust anchor for required Tls client authentication via bytes slice. + /// + /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the + /// `client_auth_` methods, then client authentication is disabled by default. + pub(crate) fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self { + let cursor = Box::new(Cursor::new(Vec::from(trust_anchor))); + self.client_auth = TlsClientAuth::Required(cursor); + self + } + /// sets the DER-encoded OCSP response pub(crate) fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self { self.ocsp_resp = Vec::from(ocsp_resp); @@ -142,7 +203,29 @@ impl TlsConfigBuilder { } }; - let mut config = ServerConfig::new(NoClientAuth::new()); + fn read_trust_anchor( + trust_anchor: Box, + ) -> Result { + let mut reader = BufReader::new(trust_anchor); + let mut store = RootCertStore::empty(); + if let Ok((0, _)) | Err(()) = store.add_pem_file(&mut reader) { + Err(TlsConfigError::CertParseError) + } else { + Ok(store) + } + } + + let client_auth = match self.client_auth { + TlsClientAuth::Off => NoClientAuth::new(), + TlsClientAuth::Optional(trust_anchor) => { + AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) + } + TlsClientAuth::Required(trust_anchor) => { + AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?) + } + }; + + let mut config = ServerConfig::new(client_auth); config .set_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new()) .map_err(|err| TlsConfigError::InvalidKey(err))?; @@ -212,8 +295,8 @@ impl AsyncRead for TlsStream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { let pin = self.get_mut(); match pin.state { State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { diff --git a/src/transport.rs b/src/transport.rs index 42f0431df..be553e706 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use hyper::server::conn::AddrStream; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub trait Transport: AsyncRead + AsyncWrite { fn remote_addr(&self) -> Option; @@ -22,8 +22,8 @@ impl AsyncRead for LiftIo { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf<'_>, + ) -> Poll> { Pin::new(&mut self.get_mut().0).poll_read(cx, buf) } }