From f36c6b255f0f7e69d6ee8548ff7326acf6ff5847 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 27 Apr 2016 11:20:07 -0700 Subject: [PATCH] feat(client): proper proxy and tunneling in Client Closes #774 --- examples/client.rs | 15 ++- src/client/mod.rs | 96 ++++++++++++----- src/client/pool.rs | 1 + src/client/proxy.rs | 240 +++++++++++++++++++++++++++++++++++++++++ src/client/request.rs | 7 +- src/http/h1.rs | 25 ++--- src/http/message.rs | 5 + src/mock.rs | 12 ++- src/net.rs | 64 ++++++++--- src/server/response.rs | 5 +- 10 files changed, 406 insertions(+), 64 deletions(-) create mode 100644 src/client/proxy.rs diff --git a/examples/client.rs b/examples/client.rs index 6d6a9384c7..5467f52c7d 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -20,7 +20,20 @@ fn main() { } }; - let client = Client::new(); + let client = match env::var("HTTP_PROXY") { + Ok(mut proxy) => { + // parse the proxy, message if it doesn't make sense + let mut port = 80; + if let Some(colon) = proxy.rfind(':') { + port = proxy[colon + 1..].parse().unwrap_or_else(|e| { + panic!("HTTP_PROXY is malformed: {:?}, port parse error: {}", proxy, e); + }); + proxy.truncate(colon); + } + Client::with_http_proxy(proxy, port) + }, + _ => Client::new() + }; let mut res = client.get(&*url) .header(Connection::close()) diff --git a/src/client/mod.rs b/src/client/mod.rs index 4ad5c2d5fd..c29fdc9f29 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -71,10 +71,12 @@ use method::Method; use net::{NetworkConnector, NetworkStream}; use Error; +use self::proxy::tunnel; pub use self::pool::Pool; pub use self::request::Request; pub use self::response::Response; +mod proxy; pub mod pool; pub mod request; pub mod response; @@ -90,7 +92,7 @@ pub struct Client { redirect_policy: RedirectPolicy, read_timeout: Option, write_timeout: Option, - proxy: Option<(Cow<'static, str>, Cow<'static, str>, u16)> + proxy: Option<(Cow<'static, str>, u16)> } impl fmt::Debug for Client { @@ -116,6 +118,15 @@ impl Client { Client::with_connector(Pool::new(config)) } + pub fn with_http_proxy(host: H, port: u16) -> Client + where H: Into> { + let host = host.into(); + let proxy = tunnel((host.clone(), port)); + let mut client = Client::with_connector(Pool::with_connector(Default::default(), proxy)); + client.proxy = Some((host, port)); + client + } + /// Create a new client with a specific connector. pub fn with_connector(connector: C) -> Client where C: NetworkConnector + Send + Sync + 'static, S: NetworkStream + Send { @@ -148,12 +159,6 @@ impl Client { self.write_timeout = dur; } - /// Set a proxy for requests of this Client. - pub fn set_proxy(&mut self, scheme: S, host: H, port: u16) - where S: Into>, H: Into> { - self.proxy = Some((scheme.into(), host.into(), port)); - } - /// Build a Get request. pub fn get(&self, url: U) -> RequestBuilder { self.request(Method::Get, url) @@ -271,13 +276,12 @@ impl<'a> RequestBuilder<'a> { loop { let mut req = { - let (scheme, host, port) = match client.proxy { - Some(ref proxy) => (proxy.0.as_ref(), proxy.1.as_ref(), proxy.2), - None => { - let hp = try!(get_host_and_port(&url)); - (url.scheme(), hp.0, hp.1) - } - }; + let (host, port) = try!(get_host_and_port(&url)); + let mut message = try!(client.protocol.new_message(&host, port, url.scheme())); + if url.scheme() == "http" && client.proxy.is_some() { + message.set_proxied(true); + } + let mut headers = match headers { Some(ref headers) => headers.clone(), None => Headers::new(), @@ -286,7 +290,6 @@ impl<'a> RequestBuilder<'a> { hostname: host.to_owned(), port: Some(port), }); - let message = try!(client.protocol.new_message(&host, port, scheme)); Request::with_headers_and_message(method.clone(), url.clone(), headers, message) }; @@ -460,6 +463,7 @@ impl Default for RedirectPolicy { } } + fn get_host_and_port(url: &Url) -> ::Result<(&str, u16)> { let host = match url.host_str() { Some(host) => host, @@ -479,8 +483,9 @@ mod tests { use std::io::Read; use header::Server; use http::h1::Http11Message; - use mock::{MockStream}; + use mock::{MockStream, MockSsl}; use super::{Client, RedirectPolicy}; + use super::proxy::Proxy; use super::pool::Pool; use url::Url; @@ -505,24 +510,61 @@ mod tests { #[test] fn test_proxy() { use super::pool::PooledStream; + type MessageStream = PooledStream>; mock_connector!(ProxyConnector { b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" }); - let mut client = Client::with_connector(Pool::with_connector(Default::default(), ProxyConnector)); - client.set_proxy("http", "example.proxy", 8008); + let tunnel = Proxy { + connector: ProxyConnector, + proxy: ("example.proxy".into(), 8008), + ssl: MockSsl, + }; + let mut client = Client::with_connector(Pool::with_connector(Default::default(), tunnel)); + client.proxy = Some(("example.proxy".into(), 8008)); let mut dump = vec![]; client.get("http://127.0.0.1/foo/bar").send().unwrap().read_to_end(&mut dump).unwrap(); - { - let box_message = client.protocol.new_message("example.proxy", 8008, "http").unwrap(); - let message = box_message.downcast::().unwrap(); - let stream = message.into_inner().downcast::>().unwrap().into_inner(); - let s = ::std::str::from_utf8(&stream.write).unwrap(); - let request_line = "GET http://127.0.0.1/foo/bar HTTP/1.1\r\n"; - assert_eq!(&s[..request_line.len()], request_line); - assert!(s.contains("Host: example.proxy:8008\r\n")); - } + let box_message = client.protocol.new_message("127.0.0.1", 80, "http").unwrap(); + let message = box_message.downcast::().unwrap(); + let stream = message.into_inner().downcast::().unwrap().into_inner().into_normal().unwrap();; + + let s = ::std::str::from_utf8(&stream.write).unwrap(); + let request_line = "GET http://127.0.0.1/foo/bar HTTP/1.1\r\n"; + assert!(s.starts_with(request_line), "{:?} doesn't start with {:?}", s, request_line); + assert!(s.contains("Host: 127.0.0.1\r\n")); + } + + #[test] + fn test_proxy_tunnel() { + use super::pool::PooledStream; + type MessageStream = PooledStream>; + + mock_connector!(ProxyConnector { + b"HTTP/1.1 200 OK\r\n\r\n", + b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" + }); + let tunnel = Proxy { + connector: ProxyConnector, + proxy: ("example.proxy".into(), 8008), + ssl: MockSsl, + }; + let mut client = Client::with_connector(Pool::with_connector(Default::default(), tunnel)); + client.proxy = Some(("example.proxy".into(), 8008)); + let mut dump = vec![]; + client.get("https://127.0.0.1/foo/bar").send().unwrap().read_to_end(&mut dump).unwrap(); + + let box_message = client.protocol.new_message("127.0.0.1", 443, "https").unwrap(); + let message = box_message.downcast::().unwrap(); + let stream = message.into_inner().downcast::().unwrap().into_inner().into_tunneled().unwrap(); + + let s = ::std::str::from_utf8(&stream.write).unwrap(); + let connect_line = "CONNECT 127.0.0.1:443 HTTP/1.1\r\nHost: 127.0.0.1:443\r\n\r\n"; + assert_eq!(&s[..connect_line.len()], connect_line); + let s = &s[connect_line.len()..]; + let request_line = "GET /foo/bar HTTP/1.1\r\n"; + assert_eq!(&s[..request_line.len()], request_line); + assert!(s.contains("Host: 127.0.0.1\r\n")); } #[test] diff --git a/src/client/pool.rs b/src/client/pool.rs index 2c44fb37c1..fd86b0cd8c 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -127,6 +127,7 @@ impl, S: NetworkStream + Send> NetworkConnector fo } /// A Stream that will try to be returned to the Pool when dropped. +#[derive(Debug)] pub struct PooledStream { inner: Option>, is_closed: bool, diff --git a/src/client/proxy.rs b/src/client/proxy.rs new file mode 100644 index 0000000000..923d12eaa3 --- /dev/null +++ b/src/client/proxy.rs @@ -0,0 +1,240 @@ +use std::borrow::Cow; +use std::io; +use std::net::{SocketAddr, Shutdown}; +use std::time::Duration; + +use method::Method; +use net::{NetworkConnector, HttpConnector, NetworkStream, SslClient}; + +#[cfg(all(feature = "openssl", not(feature = "security-framework")))] +pub fn tunnel(proxy: (Cow<'static, str>, u16)) -> Proxy { + Proxy { + connector: HttpConnector, + proxy: proxy, + ssl: Default::default() + } +} + +#[cfg(feature = "security-framework")] +pub fn tunnel(proxy: (Cow<'static, str>, u16)) -> Proxy { + Proxy { + connector: HttpConnector, + proxy: proxy, + ssl: Default::default() + } +} + +#[cfg(not(any(feature = "openssl", feature = "security-framework")))] +pub fn tunnel(proxy: (Cow<'static, str>, u16)) -> Proxy { + Proxy { + connector: HttpConnector, + proxy: proxy, + ssl: self::no_ssl::Plaintext, + } + +} + +pub struct Proxy +where C: NetworkConnector + Send + Sync + 'static, + C::Stream: NetworkStream + Send + Clone, + S: SslClient { + pub connector: C, + pub proxy: (Cow<'static, str>, u16), + pub ssl: S, +} + + +impl NetworkConnector for Proxy +where C: NetworkConnector + Send + Sync + 'static, + C::Stream: NetworkStream + Send + Clone, + S: SslClient { + type Stream = Proxied; + + fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result { + use httparse; + use std::io::{Read, Write}; + use ::version::HttpVersion::Http11; + trace!("{:?} proxy for '{}://{}:{}'", self.proxy, scheme, host, port); + match scheme { + "http" => { + self.connector.connect(self.proxy.0.as_ref(), self.proxy.1, "http") + .map(Proxied::Normal) + }, + "https" => { + let mut stream = try!(self.connector.connect(self.proxy.0.as_ref(), self.proxy.1, "http")); + trace!("{:?} CONNECT {}:{}", self.proxy, host, port); + try!(write!(&mut stream, "{method} {host}:{port} {version}\r\nHost: {host}:{port}\r\n\r\n", + method=Method::Connect, host=host, port=port, version=Http11)); + try!(stream.flush()); + let mut buf = [0; 1024]; + let mut n = 0; + while n < buf.len() { + n += try!(stream.read(&mut buf[n..])); + let mut headers = [httparse::EMPTY_HEADER; 10]; + let mut res = httparse::Response::new(&mut headers); + if try!(res.parse(&buf[..n])).is_complete() { + let code = res.code.expect("complete parsing lost code"); + if code >= 200 && code < 300 { + trace!("CONNECT success = {:?}", code); + return self.ssl.wrap_client(stream, host) + .map(Proxied::Tunneled) + } else { + trace!("CONNECT response = {:?}", code); + return Err(::Error::Status); + } + } + } + Err(::Error::TooLarge) + }, + _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid scheme").into()) + } + } +} + +#[derive(Debug)] +pub enum Proxied { + Normal(T1), + Tunneled(T2) +} + +#[cfg(test)] +impl Proxied { + pub fn into_normal(self) -> Result { + match self { + Proxied::Normal(t1) => Ok(t1), + _ => Err(self) + } + } + + pub fn into_tunneled(self) -> Result { + match self { + Proxied::Tunneled(t2) => Ok(t2), + _ => Err(self) + } + } +} + +impl io::Read for Proxied { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match *self { + Proxied::Normal(ref mut t) => io::Read::read(t, buf), + Proxied::Tunneled(ref mut t) => io::Read::read(t, buf), + } + } +} + +impl io::Write for Proxied { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + match *self { + Proxied::Normal(ref mut t) => io::Write::write(t, buf), + Proxied::Tunneled(ref mut t) => io::Write::write(t, buf), + } + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + match *self { + Proxied::Normal(ref mut t) => io::Write::flush(t), + Proxied::Tunneled(ref mut t) => io::Write::flush(t), + } + } +} + +impl NetworkStream for Proxied { + #[inline] + fn peer_addr(&mut self) -> io::Result { + match *self { + Proxied::Normal(ref mut s) => s.peer_addr(), + Proxied::Tunneled(ref mut s) => s.peer_addr() + } + } + + #[inline] + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + match *self { + Proxied::Normal(ref inner) => inner.set_read_timeout(dur), + Proxied::Tunneled(ref inner) => inner.set_read_timeout(dur) + } + } + + #[inline] + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + match *self { + Proxied::Normal(ref inner) => inner.set_write_timeout(dur), + Proxied::Tunneled(ref inner) => inner.set_write_timeout(dur) + } + } + + #[inline] + fn close(&mut self, how: Shutdown) -> io::Result<()> { + match *self { + Proxied::Normal(ref mut s) => s.close(how), + Proxied::Tunneled(ref mut s) => s.close(how) + } + } +} + +#[cfg(not(any(feature = "openssl", feature = "security-framework")))] +mod no_ssl { + use std::io; + use std::net::{Shutdown, SocketAddr}; + use std::time::Duration; + + use net::{SslClient, NetworkStream}; + + pub struct Plaintext; + + #[derive(Clone)] + pub enum Void {} + + impl io::Read for Void { + #[inline] + fn read(&mut self, _buf: &mut [u8]) -> io::Result { + match *self {} + } + } + + impl io::Write for Void { + #[inline] + fn write(&mut self, _buf: &[u8]) -> io::Result { + match *self {} + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + match *self {} + } + } + + impl NetworkStream for Void { + #[inline] + fn peer_addr(&mut self) -> io::Result { + match *self {} + } + + #[inline] + fn set_read_timeout(&self, _dur: Option) -> io::Result<()> { + match *self {} + } + + #[inline] + fn set_write_timeout(&self, _dur: Option) -> io::Result<()> { + match *self {} + } + + #[inline] + fn close(&mut self, _how: Shutdown) -> io::Result<()> { + match *self {} + } + } + + impl SslClient for Plaintext { + type Stream = Void; + + fn wrap_client(&self, _stream: T, _host: &str) -> ::Result { + Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid scheme").into()) + } + } +} diff --git a/src/client/request.rs b/src/client/request.rs index db0fce0a94..8743373c5b 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -268,16 +268,15 @@ mod tests { #[test] fn test_proxy() { let url = Url::parse("http://example.dom").unwrap(); - let proxy_url = Url::parse("http://pro.xy").unwrap(); let mut req = Request::with_connector( - Get, proxy_url, &mut MockConnector + Get, url, &mut MockConnector ).unwrap(); - req.url = url; + req.message.set_proxied(true); let bytes = run_request(req); let s = from_utf8(&bytes[..]).unwrap(); let request_line = "GET http://example.dom/ HTTP/1.1"; assert_eq!(&s[..request_line.len()], request_line); - assert!(s.contains("Host: pro.xy")); + assert!(s.contains("Host: example.dom")); } #[test] diff --git a/src/http/h1.rs b/src/http/h1.rs index 7b9e4abd21..77a3dda96d 100644 --- a/src/http/h1.rs +++ b/src/http/h1.rs @@ -11,7 +11,7 @@ use url::Position as UrlPosition; use buffer::BufReader; use Error; -use header::{Headers, Host, ContentLength, TransferEncoding}; +use header::{Headers, ContentLength, TransferEncoding}; use header::Encoding::Chunked; use method::{Method}; use net::{NetworkConnector, NetworkStream}; @@ -91,6 +91,7 @@ impl Stream { /// An implementation of the `HttpMessage` trait for HTTP/1.1. #[derive(Debug)] pub struct Http11Message { + is_proxied: bool, method: Option, stream: Wrapper, } @@ -131,6 +132,7 @@ impl HttpMessage for Http11Message { io::ErrorKind::Other, ""))); let mut method = None; + let is_proxied = self.is_proxied; self.stream.map_in_place(|stream: Stream| -> Stream { let stream = match stream { Stream::Idle(stream) => stream, @@ -144,17 +146,10 @@ impl HttpMessage for Http11Message { let mut stream = BufWriter::new(stream); { - let uri = match head.headers.get::() { - Some(host) - if Some(&*host.hostname) == head.url.host_str() - && host.port == head.url.port_or_known_default() => { - &head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery] - }, - _ => { - trace!("url and host header dont match, using absolute uri form"); - head.url.as_ref() - } - + let uri = if is_proxied { + head.url.as_ref() + } else { + &head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery] }; let version = version::HttpVersion::Http11; @@ -365,6 +360,11 @@ impl HttpMessage for Http11Message { try!(self.get_mut().close(Shutdown::Both)); Ok(()) } + + #[inline] + fn set_proxied(&mut self, val: bool) { + self.is_proxied = val; + } } impl Http11Message { @@ -401,6 +401,7 @@ impl Http11Message { /// the peer. pub fn with_stream(stream: Box) -> Http11Message { Http11Message { + is_proxied: false, method: None, stream: Wrapper::new(Stream::new(stream)), } diff --git a/src/http/message.rs b/src/http/message.rs index 2cffbfdd86..d983fafa03 100644 --- a/src/http/message.rs +++ b/src/http/message.rs @@ -70,6 +70,11 @@ pub trait HttpMessage: Write + Read + Send + Any + Typeable + Debug { fn close_connection(&mut self) -> ::Result<()>; /// Returns whether the incoming message has a body. fn has_body(&self) -> bool; + /// Called when the Client wishes to use a Proxy. + fn set_proxied(&mut self, val: bool) { + // default implementation so as to not be a breaking change. + warn!("default set_proxied({:?})", val); + } } impl HttpMessage { diff --git a/src/mock.rs b/src/mock.rs index cb8005f4ec..ac70a5159e 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -12,7 +12,7 @@ use solicit::http::frame::{SettingsFrame, Frame}; use solicit::http::connection::{HttpConnection, EndStream, DataChunk}; use header::Headers; -use net::{NetworkStream, NetworkConnector}; +use net::{NetworkStream, NetworkConnector, SslClient}; #[derive(Clone, Debug)] pub struct MockStream { @@ -315,3 +315,13 @@ impl NetworkConnector for MockHttp2Connector { Ok(self.streams.borrow_mut().remove(0)) } } + +#[derive(Debug, Default)] +pub struct MockSsl; + +impl SslClient for MockSsl { + type Stream = T; + fn wrap_client(&self, stream: T, _host: &str) -> ::Result { + Ok(stream) + } +} diff --git a/src/net.rs b/src/net.rs index 6b4fadd43c..c64ce5964d 100644 --- a/src/net.rs +++ b/src/net.rs @@ -6,7 +6,7 @@ use std::net::{SocketAddr, ToSocketAddrs, TcpStream, TcpListener, Shutdown}; use std::mem; #[cfg(feature = "openssl")] -pub use self::openssl::Openssl; +pub use self::openssl::{Openssl, OpensslClient}; use std::time::Duration; @@ -423,22 +423,22 @@ pub trait Ssl { } /// An abstraction to allow any SSL implementation to be used with client-side HttpsStreams. -pub trait SslClient { +pub trait SslClient { /// The protected stream. type Stream: NetworkStream + Send + Clone; /// Wrap a client stream with SSL. - fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result; + fn wrap_client(&self, stream: T, host: &str) -> ::Result; } /// An abstraction to allow any SSL implementation to be used with server-side HttpsStreams. -pub trait SslServer { +pub trait SslServer { /// The protected stream. type Stream: NetworkStream + Send + Clone; /// Wrap a server stream with SSL. - fn wrap_server(&self, stream: HttpStream) -> ::Result; + fn wrap_server(&self, stream: T) -> ::Result; } -impl SslClient for S { +impl SslClient for S { type Stream = ::Stream; fn wrap_client(&self, stream: HttpStream, host: &str) -> ::Result { @@ -446,7 +446,7 @@ impl SslClient for S { } } -impl SslServer for S { +impl SslServer for S { type Stream = ::Stream; fn wrap_server(&self, stream: HttpStream) -> ::Result { @@ -566,28 +566,35 @@ impl NetworkListener for HttpsListener { /// A connector that can protect HTTP streams using SSL. #[derive(Debug, Default)] -pub struct HttpsConnector { - ssl: S +pub struct HttpsConnector { + ssl: S, + connector: C, +} + +impl HttpsConnector { + /// Create a new connector using the provided SSL implementation. + pub fn new(s: S) -> HttpsConnector { + HttpsConnector::with_connector(s, HttpConnector) + } } -impl HttpsConnector { +impl HttpsConnector { /// Create a new connector using the provided SSL implementation. - pub fn new(s: S) -> HttpsConnector { - HttpsConnector { ssl: s } + pub fn with_connector(s: S, connector: C) -> HttpsConnector { + HttpsConnector { ssl: s, connector: connector } } } -impl NetworkConnector for HttpsConnector { +impl> NetworkConnector for HttpsConnector { type Stream = HttpsStream; fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result { - let addr = &(host, port); + let stream = try!(self.connector.connect(host, port, "http")); if scheme == "https" { debug!("https scheme"); - let stream = HttpStream(try!(TcpStream::connect(addr))); self.ssl.wrap_client(stream, host).map(HttpsStream::Https) } else { - HttpConnector.connect(host, port, scheme).map(HttpsStream::Http) + Ok(HttpsStream::Http(stream)) } } } @@ -638,6 +645,31 @@ mod openssl { pub context: Arc } + /// A client-specific implementation of OpenSSL. + #[derive(Debug, Clone)] + pub struct OpensslClient(SslContext); + + impl Default for OpensslClient { + fn default() -> OpensslClient { + OpensslClient(SslContext::new(SslMethod::Sslv23).unwrap_or_else(|e| { + // if we cannot create a SslContext, that's because of a + // serious problem. just crash. + panic!("{}", e) + })) + } + } + + + impl super::SslClient for OpensslClient { + type Stream = SslStream; + + fn wrap_client(&self, stream: T, host: &str) -> ::Result { + let ssl = try!(Ssl::new(&self.0)); + try!(ssl.set_hostname(host)); + SslStream::connect(ssl, stream).map_err(From::from) + } + } + impl Default for Openssl { fn default() -> Openssl { Openssl { diff --git a/src/server/response.rs b/src/server/response.rs index 04c5718db9..05e160733c 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -12,7 +12,7 @@ use std::thread; use time::now_utc; use header; -use http::h1::{CR, LF, LINE_ENDING, HttpWriter}; +use http::h1::{LINE_ENDING, HttpWriter}; use http::h1::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter}; use status; use net::{Fresh, Streaming}; @@ -82,8 +82,7 @@ impl<'a, W: Any> Response<'a, W> { fn write_head(&mut self) -> io::Result { debug!("writing head: {:?} {:?}", self.version, self.status); - try!(write!(&mut self.body, "{} {}{}{}", self.version, self.status, - CR as char, LF as char)); + try!(write!(&mut self.body, "{} {}\r\n", self.version, self.status)); if !self.headers.has::() { self.headers.set(header::Date(header::HttpDate(now_utc())));