diff --git a/examples/hello.rs b/examples/hello.rs index 7740e8ee65..3a1d865a35 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -10,7 +10,6 @@ static PHRASE: &'static [u8] = b"Hello World!"; fn main() { pretty_env_logger::init(); - let addr = ([127, 0, 0, 1], 3000).into(); // new_service is run for each connection, creating a 'service' diff --git a/src/client/connect.rs b/src/client/connect.rs index d347426e3c..ab744d67c5 100644 --- a/src/client/connect.rs +++ b/src/client/connect.rs @@ -7,6 +7,7 @@ //! - The [`Connect`](Connect) trait and related types to build custom connectors. use std::error::Error as StdError; use std::mem; +use std::net::SocketAddr; use bytes::{BufMut, BytesMut}; use futures::Future; @@ -42,10 +43,11 @@ pub struct Destination { /// /// This can be used to inform recipients about things like if ALPN /// was used, or if connected to an HTTP proxy. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Connected { //alpn: Alpn, pub(super) is_proxied: bool, + pub(super) remote_addr: Option, } /*TODO: when HTTP1 Upgrades to H2 are added, this will be needed @@ -234,8 +236,8 @@ impl Connected { /// Create new `Connected` type with empty metadata. pub fn new() -> Connected { Connected { - //alpn: Alpn::Http1, is_proxied: false, + remote_addr: None, } } @@ -251,6 +253,14 @@ impl Connected { self } + /// Set the remote address of the connected transport. + /// + /// Default is `None`. + pub fn remote_addr(mut self, addr: SocketAddr) -> Connected { + self.remote_addr = Some(addr); + self + } + /* /// Set that the connected transport negotiated HTTP/2 as it's /// next protocol. @@ -640,16 +650,12 @@ mod http { } }, State::Resolving(ref mut future, local_addr) => { - match try!(future.poll()) { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(addrs) => { - state = State::Connecting(ConnectingTcp { - addrs: addrs, - local_addr: local_addr, - current: None, - }) - } - }; + let addrs = try_ready!(future.poll()); + state = State::Connecting(ConnectingTcp { + addrs: addrs, + local_addr: local_addr, + current: None, + }); }, State::Connecting(ref mut c) => { let sock = try_ready!(c.poll(&self.handle)); @@ -659,8 +665,11 @@ mod http { } sock.set_nodelay(self.nodelay)?; + let remote_addr = sock.peer_addr()?; + let connected = Connected::new() + .remote_addr(remote_addr); - return Ok(Async::Ready((sock, Connected::new()))); + return Ok(Async::Ready((sock, connected))); }, State::Error(ref mut e) => return Err(e.take().expect("polled more than once")), } diff --git a/src/client/mod.rs b/src/client/mod.rs index cc4137d292..0128894e80 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -92,7 +92,7 @@ use http::uri::Scheme; use body::{Body, Payload}; use common::Exec; use common::lazy as hyper_lazy; -use self::connect::{Connect, Destination}; +use self::connect::{Connect, Connected, Destination}; use self::pool::{Pool, Poolable, Reservation}; #[cfg(feature = "runtime")] pub use self::connect::HttpConnector; @@ -113,6 +113,7 @@ pub struct Client { h1_title_case_headers: bool, pool: Pool>, retry_canceled_requests: bool, + set_conn_info: bool, set_host: bool, ver: Ver, } @@ -304,7 +305,7 @@ where C: Connect + Sync + 'static, }) .map(move |tx| { pool.pooled(connecting, PoolClient { - is_proxied: connected.is_proxied, + conn_info: connected, tx: match ver { Ver::Http1 => PoolTx::Http1(tx), Ver::Http2 => PoolTx::Http2(tx.into_http2()), @@ -380,6 +381,8 @@ where C: Connect + Sync + 'static, } }); + let set_conn_info = self.set_conn_info; + let executor = self.executor.clone(); let resp = race.and_then(move |mut pooled| { let conn_reused = pooled.is_reused(); @@ -387,7 +390,7 @@ where C: Connect + Sync + 'static, // CONNECT always sends origin-form, so check it first... if req.method() == &Method::CONNECT { authority_form(req.uri_mut()); - } else if pooled.is_proxied { + } else if pooled.conn_info.is_proxied { absolute_form(req.uri_mut()); } else { origin_form(req.uri_mut()); @@ -401,6 +404,17 @@ where C: Connect + Sync + 'static, let fut = pooled.send_request_retryable(req); + let conn_info = pooled.conn_info.clone(); + let fut = fut.map(move |mut res| { + if set_conn_info { + let info = ::ext::ConnectionInfo { + remote_addr: conn_info.remote_addr, + }; + info.set(&mut res); + } + res + }); + // As of futures@0.1.21, there is a race condition in the mpsc // channel, such that sending when the receiver is closing can // result in the message being stuck inside the queue. It won't @@ -498,6 +512,7 @@ impl Clone for Client { h1_title_case_headers: self.h1_title_case_headers, pool: self.pool.clone(), retry_canceled_requests: self.retry_canceled_requests, + set_conn_info: self.set_conn_info, set_host: self.set_host, ver: self.ver, } @@ -584,7 +599,7 @@ where } struct PoolClient { - is_proxied: bool, + conn_info: Connected, tx: PoolTx, } @@ -644,17 +659,17 @@ where match self.tx { PoolTx::Http1(tx) => { Reservation::Unique(PoolClient { - is_proxied: self.is_proxied, + conn_info: self.conn_info, tx: PoolTx::Http1(tx), }) }, PoolTx::Http2(tx) => { let b = PoolClient { - is_proxied: self.is_proxied, + conn_info: self.conn_info.clone(), tx: PoolTx::Http2(tx.clone()), }; let a = PoolClient { - is_proxied: self.is_proxied, + conn_info: self.conn_info, tx: PoolTx::Http2(tx), }; Reservation::Shared(a, b) @@ -751,6 +766,7 @@ pub struct Builder { //TODO: make use of max_idle config max_idle: usize, retry_canceled_requests: bool, + set_conn_info: bool, set_host: bool, ver: Ver, } @@ -765,6 +781,7 @@ impl Default for Builder { h1_title_case_headers: false, max_idle: 5, retry_canceled_requests: true, + set_conn_info: false, set_host: true, ver: Ver::Http1, } @@ -851,6 +868,16 @@ impl Builder { self } + + /// Set whether to automatically add [`ConnectionInfo`](::ext::ConnectionInfo) + /// to `Response`s. + /// + /// Default is `false`. + pub fn set_conn_info(&mut self, val: bool) -> &mut Self { + self.set_conn_info = val; + self + } + /// Set whether to automatically add the `Host` header to requests. /// /// If true, and a request does not include a `Host` header, one will be @@ -902,6 +929,7 @@ impl Builder { h1_title_case_headers: self.h1_title_case_headers, pool: Pool::new(self.keep_alive, self.keep_alive_timeout, &self.exec), retry_canceled_requests: self.retry_canceled_requests, + set_conn_info: self.set_conn_info, set_host: self.set_host, ver: self.ver, } diff --git a/src/ext/conn_info.rs b/src/ext/conn_info.rs new file mode 100644 index 0000000000..77ef63312c --- /dev/null +++ b/src/ext/conn_info.rs @@ -0,0 +1,57 @@ +use std::net::SocketAddr; + +use super::Ext; + +/// dox +#[derive(Debug)] +pub struct ConnectionInfo { + pub(crate) remote_addr: Option, +} + +// The private type that gets put into extensions() +// +// The reason for the public and private types is to, for now, prevent +// a public API contract that crates could depend on. If the public type +// were inserted into the Extensions directly, a user could depend on +// `req.extensions().get::()`, and it's not clear that +// we want this contract yet. +#[derive(Copy, Clone, Default)] +struct ConnInfo { + remote_addr: Option, +} + +impl ConnectionInfo { + /// dox + pub fn get(extend: &E) -> ConnectionInfo + where + E: Ext, + { + let info = extend + .ext() + .get::() + .map(|&info| info) + .unwrap_or_default(); + + ConnectionInfo { + remote_addr: info.remote_addr, + } + } + + /// dox + pub(crate) fn set(self, extend: &mut E) + where + E: Ext, + { + let info = ConnInfo { + remote_addr: self.remote_addr, + }; + + extend.ext_mut().insert(info); + } + + /// dox + pub fn remote_addr(&self) -> Option { + self.remote_addr + } +} + diff --git a/src/ext/mod.rs b/src/ext/mod.rs new file mode 100644 index 0000000000..41ea9ab2f1 --- /dev/null +++ b/src/ext/mod.rs @@ -0,0 +1,44 @@ +//! dox +use http::{Extensions, Request, Response}; + +use self::sealed::{Ext, Sealed}; + +mod conn_info; + +pub use self::conn_info::ConnectionInfo; + + +mod sealed { + use http::Extensions; + + pub trait Sealed { + fn ext(&self) -> &Extensions; + fn ext_mut(&mut self) -> &mut Extensions; + } + + pub trait Ext: Sealed {} +} + +impl Sealed for Request { + fn ext(&self) -> &Extensions { + self.extensions() + } + + fn ext_mut(&mut self) -> &mut Extensions { + self.extensions_mut() + } +} + +impl Ext for Request {} + +impl Sealed for Response { + fn ext(&self) -> &Extensions { + self.extensions() + } + + fn ext_mut(&mut self) -> &mut Extensions { + self.extensions_mut() + } +} + +impl Ext for Response {} diff --git a/src/lib.rs b/src/lib.rs index 401d925c6b..64255cbb12 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ mod mock; pub mod body; pub mod client; pub mod error; +pub mod ext; mod headers; mod proto; pub mod server; diff --git a/tests/client.rs b/tests/client.rs index 6c90597f00..6f18ea421c 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -90,6 +90,7 @@ macro_rules! test { expected: $server_expected, reply: $server_reply, client: + set_conn_info: false, set_host: $set_host, title_case_headers: false, request: @@ -104,12 +105,53 @@ macro_rules! test { body: $response_body, } ); + ( name: $name:ident, server: expected: $server_expected:expr, reply: $server_reply:expr, client: + set_conn_info: $set_conn_info:expr, + request: + method: $client_method:ident, + url: $client_url:expr, + headers: { $($request_header_name:expr => $request_header_val:expr,)* }, + body: $request_body:expr, + + response: + status: $client_status:ident, + headers: { $($response_header_name:expr => $response_header_val:expr,)* }, + body: $response_body:expr, + ) => ( + test! { + name: $name, + server: + expected: $server_expected, + reply: $server_reply, + client: + set_conn_info: $set_conn_info, + set_host: true, + title_case_headers: false, + request: + method: $client_method, + url: $client_url, + headers: { $($request_header_name => $request_header_val,)* }, + body: $request_body, + + response: + status: $client_status, + headers: { $($response_header_name => $response_header_val,)* }, + body: $response_body, + } + ); + ( + name: $name:ident, + server: + expected: $server_expected:expr, + reply: $server_reply:expr, + client: + set_conn_info: $set_conn_info:expr, set_host: $set_host:expr, title_case_headers: $title_case_headers:expr, request: @@ -136,6 +178,7 @@ macro_rules! test { expected: $server_expected, reply: $server_reply, client: + set_conn_info: $set_conn_info, set_host: $set_host, title_case_headers: $title_case_headers, request: @@ -188,6 +231,7 @@ macro_rules! test { expected: $server_expected, reply: $server_reply, client: + set_conn_info: false, set_host: true, title_case_headers: false, request: @@ -214,6 +258,7 @@ macro_rules! test { expected: $server_expected:expr, reply: $server_reply:expr, client: + set_conn_info: $set_conn_info:expr, set_host: $set_host:expr, title_case_headers: $title_case_headers:expr, request: @@ -229,6 +274,7 @@ macro_rules! test { let connector = ::hyper::client::HttpConnector::new_with_handle(1, Handle::default()); let client = Client::builder() .set_host($set_host) + .set_conn_info($set_conn_info) .http1_title_case_headers($title_case_headers) .build(connector); @@ -274,7 +320,20 @@ macro_rules! test { let rx = rx.expect("thread panicked"); - rt.block_on(res.join(rx).map(|r| r.0)) + let result = rt.block_on(res.join(rx).map(|r| r.0)); + + result.map(|resp| { + let info = ::hyper::ext::ConnectionInfo::get(&resp); + + let expected = if $set_conn_info { + Some(addr) + } else { + None + }; + assert_eq!(info.remote_addr(), expected); + + resp + }) }); } @@ -678,6 +737,7 @@ test! { ", client: + set_conn_info: false, set_host: true, title_case_headers: true, request: @@ -693,6 +753,34 @@ test! { body: None, } +test! { + name: client_with_conn_info, + + server: + expected: "\ + GET / HTTP/1.1\r\n\ + host: {addr}\r\n\ + \r\n\ + ", + reply: "\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 0\r\n\ + \r\n\ + ", + + client: + set_conn_info: true, + request: + method: GET, + url: "http://{addr}/", + headers: {}, + body: None, + response: + status: OK, + headers: {}, + body: None, +} + mod dispatch_impl { use super::*; use std::io::{self, Read, Write};