diff --git a/src/relay/tcprelay/proxy_stream.rs b/src/relay/tcprelay/proxy_stream.rs index 77e12a5555f0..990b3a4a11b9 100644 --- a/src/relay/tcprelay/proxy_stream.rs +++ b/src/relay/tcprelay/proxy_stream.rs @@ -5,14 +5,15 @@ use std::{ io::{self, Error}, net::SocketAddr, pin::Pin, - task::{Context as TaskContext, Poll}, + task::{self, Poll}, time::Duration, }; -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; +use futures::ready; use log::{debug, error, trace}; -use pin_project::pin_project; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use pin_project::{pin_project, project}; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use crate::{ config::{ConfigType, ServerAddr, ServerConfig}, @@ -22,20 +23,204 @@ use crate::{ use super::{connection::Connection, CryptoStream, STcpStream}; -/// Stream wrapper for both direct connections and proxied connections -#[allow(clippy::large_enum_variant)] +enum ProxiedConnectState { + Connected(Address), + Handshaking { buf: BytesMut, data_len: usize }, + Established, +} + #[pin_project] -pub enum ProxyStream { - Direct { - #[pin] - stream: STcpStream, - context: SharedContext, - }, - Proxied { - #[pin] - stream: CryptoStream, - context: SharedContext, - }, +struct ProxiedConnection { + #[pin] + stream: CryptoStream, + state: ProxiedConnectState, +} + +impl ProxiedConnection { + fn connected(stream: CryptoStream, addr: Address) -> ProxiedConnection { + ProxiedConnection { + stream, + state: ProxiedConnectState::Connected(addr), + } + } + + fn local_addr(&self) -> io::Result { + self.stream.get_ref().get_ref().local_addr() + } +} + +impl AsyncRead for ProxiedConnection { + #[project] + fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll> { + self.project().stream.poll_read(cx, buf) + } +} + +impl AsyncWrite for ProxiedConnection { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, data: &[u8]) -> Poll> { + loop { + let this = self.as_mut().project(); + + match this.state { + ProxiedConnectState::Connected(ref addr) => { + assert_ne!(data.len(), 0); + + // Send relay address to remote + // + // NOTE: `Address` handshake packets are very small in most cases, + // so + // 1. it will be sent with the IV/Nonce data (implemented inside `CryptoStream`). + // 2. concatenating target's Address with the first data buffer (#232) + // + // For lower latency, first packet should be sent back quickly, + // so TCP_NODELAY should be kept enabled until the first data packet is received. + let addr_len = addr.serialized_len(); + let mut buf = BytesMut::with_capacity(addr_len + data.len()); + addr.write_to_buf(&mut buf); + buf.extend_from_slice(data); + + trace!("sending handshake address {} with data {} bytes", addr, data.len()); + + // Fast path + // + // For CryptoStream (Stream and AEAD), poll_write will return Ready(..) until all data have been sent out + match this.stream.poll_write(cx, &buf) { + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Ok(n)) => { + buf.advance(n); + + let remaining = buf.remaining(); + if remaining < data.len() { + // Ok, written some data with Address + let written_len = data.len() - remaining; + + trace!( + "sent handshake address {} with {} bytes of data, data len {} bytes", + addr, + written_len, + data.len(), + ); + + self.state = ProxiedConnectState::Established; + return Poll::Ready(Ok(written_len)); + } + + // FALLTHROUGH + // Handshaking branch will try to poll_write again + self.state = ProxiedConnectState::Handshaking { + buf, + data_len: data.len(), + }; + } + Poll::Pending => { + // poll_write is not ready, let Handshaking branch try again later + self.state = ProxiedConnectState::Handshaking { + buf, + data_len: data.len(), + }; + + return Poll::Pending; + } + } + } + ProxiedConnectState::Handshaking { ref mut buf, data_len } => { + let data_len = *data_len; + + // Try to write at least addr_len size + let n = ready!(this.stream.poll_write(cx, buf))?; + buf.advance(n); + + let remaining = buf.remaining(); + if remaining < data_len { + // Ok, written some data with Address + let written_len = data_len - remaining; + + trace!( + "sent handshake address with {} bytes of data, data len {} bytes", + written_len, + data_len + ); + + self.state = ProxiedConnectState::Established; + return Poll::Ready(Ok(written_len)); + } + } + ProxiedConnectState::Established => { + break; + } + } + } + + self.project().stream.poll_write(cx, data) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_shutdown(cx) + } +} + +#[pin_project] +enum ProxyConnection { + Direct(#[pin] STcpStream), + Proxied(#[pin] ProxiedConnection), +} + +impl ProxyConnection { + /// Check if the underlying connection is proxied + fn is_proxied(&self) -> bool { + match *self { + ProxyConnection::Proxied { .. } => true, + _ => false, + } + } + + fn local_addr(&self) -> io::Result { + match *self { + ProxyConnection::Direct(ref stream) => stream.get_ref().local_addr(), + ProxyConnection::Proxied(ref stream) => stream.local_addr(), + } + } +} + +macro_rules! forward_call { + ($self:expr, $method:ident $(, $param:expr)*) => { + // #[project] + match $self.as_mut().project() { + // ProxyConnection::Direct(stream) => stream.$method($($param),*), + __ProxyConnectionProjection::Direct(stream) => stream.$method($($param),*), + // ProxyConnection::Proxied(stream) => stream.$method($($param),*), + __ProxyConnectionProjection::Proxied(stream) => stream.$method($($param),*), + } + }; +} + +impl AsyncRead for ProxyConnection { + #[project] + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll> { + forward_call!(self, poll_read, cx, buf) + } +} + +impl AsyncWrite for ProxyConnection { + #[project] + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { + // let p = forward_call!(self, poll_write, cx, buf); + forward_call!(self, poll_write, cx, buf) + } + + #[project] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + forward_call!(self, poll_flush, cx) + } + + #[project] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + forward_call!(self, poll_shutdown, cx) + } } #[derive(Debug)] @@ -72,6 +257,14 @@ impl Display for ProxyStreamError { } } +/// Stream wrapper for both direct connections and proxied connections +#[pin_project] +pub struct ProxyStream { + #[pin] + connection: ProxyConnection, + context: SharedContext, +} + impl ProxyStream { /// Connect to remote by ACL rules pub async fn connect( @@ -105,9 +298,9 @@ impl ProxyStream { } }; - Ok(ProxyStream::Direct { - stream: Connection::new(stream, timeout), + Ok(ProxyStream { context, + connection: ProxyConnection::Direct(Connection::new(stream, timeout)), }) } @@ -134,11 +327,11 @@ impl ProxyStream { ); let server_stream = connect_proxy_server(&context, svr_cfg).await?; - let proxy_stream = proxy_server_handshake(context.clone(), server_stream, svr_cfg, addr).await?; + let proxy_stream = CryptoStream::new(context.clone(), server_stream, svr_cfg); - Ok(ProxyStream::Proxied { - stream: proxy_stream, + Ok(ProxyStream { context, + connection: ProxyConnection::Proxied(ProxiedConnection::connected(proxy_stream, addr.clone())), }) } @@ -161,41 +354,23 @@ impl ProxyStream { /// Returns the local socket address of this stream socket pub fn local_addr(&self) -> io::Result { - match *self { - ProxyStream::Direct { ref stream, .. } => stream.get_ref().local_addr(), - ProxyStream::Proxied { ref stream, .. } => stream.get_ref().get_ref().local_addr(), - } + self.connection.local_addr() } /// Check if the underlying connection is proxied pub fn is_proxied(&self) -> bool { - match *self { - ProxyStream::Proxied { .. } => true, - _ => false, - } + self.connection.is_proxied() } /// Get reference to context pub fn context(&self) -> &Context { - match *self { - ProxyStream::Direct { ref context, .. } => &context, - ProxyStream::Proxied { ref context, .. } => &context, - } + &self.context } } -macro_rules! forward_call { - ($self:expr, $method:ident $(, $param:expr)*) => { - match $self.as_mut().project() { - __ProxyStreamProjection::Direct { stream, .. } => stream.$method($($param),*), - __ProxyStreamProjection::Proxied { stream, .. } => stream.$method($($param),*), - } - }; -} - impl AsyncRead for ProxyStream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut [u8]) -> Poll> { - let p = forward_call!(self, poll_read, cx, buf); + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll> { + let p = self.as_mut().project().connection.poll_read(cx, buf); // Flow statistic for Android client #[cfg(feature = "local-flow-stat")] @@ -212,8 +387,8 @@ impl AsyncRead for ProxyStream { } impl AsyncWrite for ProxyStream { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll> { - let p = forward_call!(self, poll_write, cx, buf); + fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll> { + let p = self.as_mut().project().connection.poll_write(cx, buf); // Flow statistic for Android client #[cfg(feature = "local-flow-stat")] @@ -228,12 +403,12 @@ impl AsyncWrite for ProxyStream { p } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { - forward_call!(self, poll_flush, cx) + fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().connection.poll_flush(cx) } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { - forward_call!(self, poll_shutdown, cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().connection.poll_shutdown(cx) } } @@ -344,33 +519,3 @@ async fn connect_proxy_server(context: &Context, svr_cfg: &ServerConfig) -> io:: ); Err(last_err) } - -/// Handshake logic for ShadowSocks Client -async fn proxy_server_handshake( - context: SharedContext, - remote_stream: STcpStream, - svr_cfg: &ServerConfig, - relay_addr: &Address, -) -> io::Result> { - let mut stream = CryptoStream::new(context, remote_stream, svr_cfg); - - trace!("got encrypt stream and going to send addr: {:?}", relay_addr); - - // Send relay address to remote - // - // NOTE: `Address` handshake packets are very small in most cases, - // so it will be sent with the IV/Nonce data (implemented inside `CryptoStream`). - // - // For lower latency, first packet should be sent back quickly, - // so TCP_NODELAY should be kept enabled until the first data packet is received. - let mut addr_buf = BytesMut::with_capacity(relay_addr.serialized_len()); - relay_addr.write_to_buf(&mut addr_buf); - stream.write_all(&addr_buf).await?; - - // Here we should keep the TCP_NODELAY set until the first packet is received. - // https://github.com/shadowsocks/shadowsocks-libev/pull/746 - // - // Reset TCP_NODELAY after the first packet is received and sent back. - - Ok(stream) -}