diff --git a/src/proto/body.rs b/src/proto/body.rs index 3f89fdcfc0..5a6c773d7e 100644 --- a/src/proto/body.rs +++ b/src/proto/body.rs @@ -1,6 +1,6 @@ use bytes::Bytes; -use futures::{Poll, Stream}; -use futures::sync::mpsc; +use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use futures::sync::{mpsc, oneshot}; use tokio_proto; use std::borrow::Cow; @@ -12,20 +12,36 @@ pub type BodySender = mpsc::Sender>; /// A `Stream` for `Chunk`s used in requests and responses. #[must_use = "streams do nothing unless polled"] #[derive(Debug)] -pub struct Body(TokioBody); +pub struct Body(Inner); + +#[derive(Debug)] +enum Inner { + Tokio(TokioBody), + Hyper { + close_tx: oneshot::Sender<()>, + rx: mpsc::Receiver>, + } +} + +//pub(crate) +#[derive(Debug)] +pub struct ChunkSender { + close_rx: oneshot::Receiver<()>, + tx: BodySender, +} impl Body { /// Return an empty body stream #[inline] pub fn empty() -> Body { - Body(TokioBody::empty()) + Body(Inner::Tokio(TokioBody::empty())) } /// Return a body stream with an associated sender half #[inline] pub fn pair() -> (mpsc::Sender>, Body) { let (tx, rx) = TokioBody::pair(); - let rx = Body(rx); + let rx = Body(Inner::Tokio(rx)); (tx, rx) } } @@ -43,7 +59,51 @@ impl Stream for Body { #[inline] fn poll(&mut self) -> Poll, ::Error> { - self.0.poll() + match self.0 { + Inner::Tokio(ref mut rx) => rx.poll(), + Inner::Hyper { ref mut rx, .. } => match rx.poll().expect("mpsc cannot error") { + Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))), + Async::Ready(Some(Err(err))) => Err(err), + Async::Ready(None) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + }, + } + } +} + +//pub(crate) +pub fn channel() -> (ChunkSender, Body) { + let (tx, rx) = mpsc::channel(0); + let (close_tx, close_rx) = oneshot::channel(); + + let tx = ChunkSender { + close_rx: close_rx, + tx: tx, + }; + let rx = Body(Inner::Hyper { + close_tx: close_tx, + rx: rx, + }); + + (tx, rx) +} + +impl ChunkSender { + pub fn poll_ready(&mut self) -> Poll<(), ()> { + match self.close_rx.poll() { + Ok(Async::Ready(())) | Err(_) => return Err(()), + Ok(Async::NotReady) => (), + } + + self.tx.poll_ready().map_err(|_| ()) + } + + pub fn start_send(&mut self, msg: Result) -> StartSend<(), ()> { + match self.tx.start_send(msg) { + Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), + Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(())), + Err(_) => Err(()), + } } } @@ -52,7 +112,14 @@ impl Stream for Body { impl From for tokio_proto::streaming::Body { #[inline] fn from(b: Body) -> tokio_proto::streaming::Body { - b.0 + match b.0 { + Inner::Tokio(b) => b, + Inner::Hyper { close_tx, rx } => { + warn!("converting hyper::Body into a tokio_proto Body is deprecated"); + ::std::mem::forget(close_tx); + rx.into() + } + } } } @@ -61,42 +128,42 @@ impl From for tokio_proto::streaming::Body { impl From> for Body { #[inline] fn from(tokio_body: tokio_proto::streaming::Body) -> Body { - Body(tokio_body) + Body(Inner::Tokio(tokio_body)) } } impl From>> for Body { #[inline] fn from(src: mpsc::Receiver>) -> Body { - Body(src.into()) + TokioBody::from(src).into() } } impl From for Body { #[inline] fn from (chunk: Chunk) -> Body { - Body(TokioBody::from(chunk)) + TokioBody::from(chunk).into() } } impl From for Body { #[inline] fn from (bytes: Bytes) -> Body { - Body(TokioBody::from(Chunk::from(bytes))) + Body::from(TokioBody::from(Chunk::from(bytes))) } } impl From> for Body { #[inline] fn from (vec: Vec) -> Body { - Body(TokioBody::from(Chunk::from(vec))) + Body::from(TokioBody::from(Chunk::from(vec))) } } impl From<&'static [u8]> for Body { #[inline] fn from (slice: &'static [u8]) -> Body { - Body(TokioBody::from(Chunk::from(slice))) + Body::from(TokioBody::from(Chunk::from(slice))) } } @@ -113,14 +180,14 @@ impl From> for Body { impl From for Body { #[inline] fn from (s: String) -> Body { - Body(TokioBody::from(Chunk::from(s.into_bytes()))) + Body::from(TokioBody::from(Chunk::from(s.into_bytes()))) } } impl From<&'static str> for Body { #[inline] fn from(slice: &'static str) -> Body { - Body(TokioBody::from(Chunk::from(slice.as_bytes()))) + Body::from(TokioBody::from(Chunk::from(slice.as_bytes()))) } } diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs index 0133f7989d..cdd9ae4971 100644 --- a/src/proto/dispatch.rs +++ b/src/proto/dispatch.rs @@ -1,6 +1,6 @@ use std::io; -use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; +use futures::{Async, AsyncSink, Future, Poll, Stream}; use futures::sync::{mpsc, oneshot}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_service::Service; @@ -11,7 +11,7 @@ use ::StatusCode; pub struct Dispatcher { conn: Conn, dispatch: D, - body_tx: Option, + body_tx: Option, body_rx: Option, is_closing: bool, } @@ -22,6 +22,7 @@ pub trait Dispatch { type RecvItem; fn poll_msg(&mut self) -> Poll)>, ::Error>; fn recv_msg(&mut self, msg: ::Result<(Self::RecvItem, Option)>) -> ::Result<()>; + fn poll_ready(&mut self) -> Poll<(), ()>; fn should_poll(&self) -> bool; } @@ -70,10 +71,22 @@ where if self.is_closing { return Ok(Async::Ready(())); } else if self.conn.can_read_head() { + // can dispatch receive, or does it still care about, an incoming message? + match self.dispatch.poll_ready() { + Ok(Async::Ready(())) => (), + Ok(Async::NotReady) => unreachable!("dispatch not ready when conn is"), + Err(()) => { + trace!("dispatch no longer receiving messages"); + self.is_closing = true; + return Ok(Async::Ready(())); + } + } + // dispatch is ready for a message, try to read one match self.conn.read_head() { Ok(Async::Ready(Some((head, has_body)))) => { let body = if has_body { - let (tx, rx) = super::Body::pair(); + let (mut tx, rx) = super::body::channel(); + let _ = tx.poll_ready(); // register this task if rx is dropped self.body_tx = Some(tx); Some(rx) } else { @@ -111,6 +124,8 @@ where self.conn.close_read(); return Ok(Async::Ready(())); } + // else the conn body is done, and user dropped, + // so everything is fine! } } if can_read_body { @@ -133,7 +148,7 @@ where } }, Ok(Async::Ready(None)) => { - let _ = body.close(); + // just drop, the body will close automatically }, Ok(Async::NotReady) => { self.body_tx = Some(body); @@ -144,7 +159,7 @@ where } } } else { - let _ = body.close(); + // just drop, the body will close automatically } } else if !T::should_read_first() { self.conn.try_empty_read()?; @@ -305,6 +320,14 @@ where Ok(()) } + fn poll_ready(&mut self) -> Poll<(), ()> { + if self.in_flight.is_some() { + Ok(Async::NotReady) + } else { + Ok(Async::Ready(())) + } + } + fn should_poll(&self) -> bool { self.in_flight.is_some() } @@ -333,9 +356,18 @@ where fn poll_msg(&mut self) -> Poll)>, ::Error> { match self.rx.poll() { - Ok(Async::Ready(Some(ClientMsg::Request(head, body, cb)))) => { - self.callback = Some(cb); - Ok(Async::Ready(Some((head, body)))) + Ok(Async::Ready(Some(ClientMsg::Request(head, body, mut cb)))) => { + // check that future hasn't been canceled already + match cb.poll_cancel().expect("poll_cancel cannot error") { + Async::Ready(()) => { + trace!("request canceled"); + Ok(Async::Ready(None)) + }, + Async::NotReady => { + self.callback = Some(cb); + Ok(Async::Ready(Some((head, body)))) + } + } }, Ok(Async::Ready(Some(ClientMsg::Close))) | Ok(Async::Ready(None)) => { @@ -370,6 +402,20 @@ where } } + fn poll_ready(&mut self) -> Poll<(), ()> { + match self.callback { + Some(ref mut cb) => match cb.poll_cancel() { + Ok(Async::Ready(())) => { + trace!("callback receiver has dropped"); + Err(()) + }, + Ok(Async::NotReady) => Ok(Async::Ready(())), + Err(_) => unreachable!("oneshot poll_cancel cannot error"), + }, + None => Err(()), + } + } + fn should_poll(&self) -> bool { self.callback.is_none() } diff --git a/tests/client.rs b/tests/client.rs index f86b8f78f0..16b7ab00f0 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -704,6 +704,97 @@ mod dispatch_impl { assert_eq!(closes.load(Ordering::Relaxed), 1); } + #[test] + fn drop_response_future_closes_in_progress_connection() { + let _ = pretty_env_logger::init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + + let (tx1, rx1) = oneshot::channel(); + let (_client_drop_tx, client_drop_rx) = oneshot::channel::<()>(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + // we never write a response head + // simulates a slow server operation + let _ = tx1.send(()); + + // prevent this thread from closing until end of test, so the connection + // stays open and idle until Client is dropped + let _ = client_drop_rx.wait(); + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + + let res = { + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone())) + .no_proto() + .build(&handle); + client.get(uri) + }; + + //let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + core.run(res.select2(rx1)).unwrap(); + // res now dropped + core.run(Timeout::new(Duration::from_millis(100), &handle).unwrap()).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + } + + #[test] + fn drop_response_body_closes_in_progress_connection() { + let _ = pretty_env_logger::init(); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let closes = Arc::new(AtomicUsize::new(0)); + + let (tx1, rx1) = oneshot::channel(); + let (_client_drop_tx, client_drop_rx) = oneshot::channel::<()>(); + + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + write!(sock, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n").expect("write head"); + let _ = tx1.send(()); + + // prevent this thread from closing until end of test, so the connection + // stays open and idle until Client is dropped + let _ = client_drop_rx.wait(); + }); + + let uri = format!("http://{}/a", addr).parse().unwrap(); + + let res = { + let client = Client::configure() + .connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone())) + .no_proto() + .build(&handle); + // notably, havent read body yet + client.get(uri) + }; + + let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); + core.run(res.join(rx).map(|r| r.0)).unwrap(); + core.run(Timeout::new(Duration::from_millis(100), &handle).unwrap()).unwrap(); + + assert_eq!(closes.load(Ordering::Relaxed), 1); + } + #[test] fn no_keep_alive_closes_connection() { // https://github.com/hyperium/hyper/issues/1383