Skip to content

Commit

Permalink
[#232] Send data along with header (handshake)
Browse files Browse the repository at this point in the history
  • Loading branch information
zonyitoo committed May 3, 2020
1 parent a3cbc57 commit 45a3469
Showing 1 changed file with 226 additions and 81 deletions.
307 changes: 226 additions & 81 deletions src/relay/tcprelay/proxy_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ use std::{
io::{self, Error},
net::SocketAddr,
pin::Pin,
task::{Context as TaskContext, Poll},
task::{self, Poll},
time::Duration,
};

use bytes::BytesMut;
use bytes::{Buf, BytesMut};
use futures::ready;
use log::{debug, error, trace};
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use pin_project::{pin_project, project};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};

use crate::{
config::{ConfigType, ServerAddr, ServerConfig},
Expand All @@ -22,20 +23,204 @@ use crate::{

use super::{connection::Connection, CryptoStream, STcpStream};

/// Stream wrapper for both direct connections and proxied connections
#[allow(clippy::large_enum_variant)]
enum ProxiedConnectState {
Connected(Address),
Handshaking { buf: BytesMut, data_len: usize },
Established,
}

#[pin_project]
pub enum ProxyStream {
Direct {
#[pin]
stream: STcpStream,
context: SharedContext,
},
Proxied {
#[pin]
stream: CryptoStream<STcpStream>,
context: SharedContext,
},
struct ProxiedConnection {
#[pin]
stream: CryptoStream<STcpStream>,
state: ProxiedConnectState,
}

impl ProxiedConnection {
fn connected(stream: CryptoStream<STcpStream>, addr: Address) -> ProxiedConnection {
ProxiedConnection {
stream,
state: ProxiedConnectState::Connected(addr),
}
}

fn local_addr(&self) -> io::Result<SocketAddr> {
self.stream.get_ref().get_ref().local_addr()
}
}

impl AsyncRead for ProxiedConnection {
#[project]
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.project().stream.poll_read(cx, buf)
}
}

impl AsyncWrite for ProxiedConnection {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, data: &[u8]) -> Poll<io::Result<usize>> {
loop {
let this = self.as_mut().project();

match this.state {
ProxiedConnectState::Connected(ref addr) => {
assert_ne!(data.len(), 0);

// Send relay address to remote
//
// NOTE: `Address` handshake packets are very small in most cases,
// so
// 1. it will be sent with the IV/Nonce data (implemented inside `CryptoStream`).
// 2. concatenating target's Address with the first data buffer (#232)
//
// For lower latency, first packet should be sent back quickly,
// so TCP_NODELAY should be kept enabled until the first data packet is received.
let addr_len = addr.serialized_len();
let mut buf = BytesMut::with_capacity(addr_len + data.len());
addr.write_to_buf(&mut buf);
buf.extend_from_slice(data);

trace!("sending handshake address {} with data {} bytes", addr, data.len());

// Fast path
//
// For CryptoStream (Stream and AEAD), poll_write will return Ready(..) until all data have been sent out
match this.stream.poll_write(cx, &buf) {
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Ready(Ok(n)) => {
buf.advance(n);

let remaining = buf.remaining();
if remaining < data.len() {
// Ok, written some data with Address
let written_len = data.len() - remaining;

trace!(
"sent handshake address {} with {} bytes of data, data len {} bytes",
addr,
written_len,
data.len(),
);

self.state = ProxiedConnectState::Established;
return Poll::Ready(Ok(written_len));
}

// FALLTHROUGH
// Handshaking branch will try to poll_write again
self.state = ProxiedConnectState::Handshaking {
buf,
data_len: data.len(),
};
}
Poll::Pending => {
// poll_write is not ready, let Handshaking branch try again later
self.state = ProxiedConnectState::Handshaking {
buf,
data_len: data.len(),
};

return Poll::Pending;
}
}
}
ProxiedConnectState::Handshaking { ref mut buf, data_len } => {
let data_len = *data_len;

// Try to write at least addr_len size
let n = ready!(this.stream.poll_write(cx, buf))?;
buf.advance(n);

let remaining = buf.remaining();
if remaining < data_len {
// Ok, written some data with Address
let written_len = data_len - remaining;

trace!(
"sent handshake address with {} bytes of data, data len {} bytes",
written_len,
data_len
);

self.state = ProxiedConnectState::Established;
return Poll::Ready(Ok(written_len));
}
}
ProxiedConnectState::Established => {
break;
}
}
}

self.project().stream.poll_write(cx, data)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
self.project().stream.poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
self.project().stream.poll_shutdown(cx)
}
}

#[pin_project]
enum ProxyConnection {
Direct(#[pin] STcpStream),
Proxied(#[pin] ProxiedConnection),
}

impl ProxyConnection {
/// Check if the underlying connection is proxied
fn is_proxied(&self) -> bool {
match *self {
ProxyConnection::Proxied { .. } => true,
_ => false,
}
}

fn local_addr(&self) -> io::Result<SocketAddr> {
match *self {
ProxyConnection::Direct(ref stream) => stream.get_ref().local_addr(),
ProxyConnection::Proxied(ref stream) => stream.local_addr(),
}
}
}

macro_rules! forward_call {
($self:expr, $method:ident $(, $param:expr)*) => {
// #[project]
match $self.as_mut().project() {
// ProxyConnection::Direct(stream) => stream.$method($($param),*),
__ProxyConnectionProjection::Direct(stream) => stream.$method($($param),*),
// ProxyConnection::Proxied(stream) => stream.$method($($param),*),
__ProxyConnectionProjection::Proxied(stream) => stream.$method($($param),*),
}
};
}

impl AsyncRead for ProxyConnection {
#[project]
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
forward_call!(self, poll_read, cx, buf)
}
}

impl AsyncWrite for ProxyConnection {
#[project]
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
// let p = forward_call!(self, poll_write, cx, buf);
forward_call!(self, poll_write, cx, buf)
}

#[project]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
forward_call!(self, poll_flush, cx)
}

#[project]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
forward_call!(self, poll_shutdown, cx)
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -72,6 +257,14 @@ impl Display for ProxyStreamError {
}
}

/// Stream wrapper for both direct connections and proxied connections
#[pin_project]
pub struct ProxyStream {
#[pin]
connection: ProxyConnection,
context: SharedContext,
}

impl ProxyStream {
/// Connect to remote by ACL rules
pub async fn connect(
Expand Down Expand Up @@ -105,9 +298,9 @@ impl ProxyStream {
}
};

Ok(ProxyStream::Direct {
stream: Connection::new(stream, timeout),
Ok(ProxyStream {
context,
connection: ProxyConnection::Direct(Connection::new(stream, timeout)),
})
}

Expand All @@ -134,11 +327,11 @@ impl ProxyStream {
);

let server_stream = connect_proxy_server(&context, svr_cfg).await?;
let proxy_stream = proxy_server_handshake(context.clone(), server_stream, svr_cfg, addr).await?;
let proxy_stream = CryptoStream::new(context.clone(), server_stream, svr_cfg);

Ok(ProxyStream::Proxied {
stream: proxy_stream,
Ok(ProxyStream {
context,
connection: ProxyConnection::Proxied(ProxiedConnection::connected(proxy_stream, addr.clone())),
})
}

Expand All @@ -161,41 +354,23 @@ impl ProxyStream {

/// Returns the local socket address of this stream socket
pub fn local_addr(&self) -> io::Result<SocketAddr> {
match *self {
ProxyStream::Direct { ref stream, .. } => stream.get_ref().local_addr(),
ProxyStream::Proxied { ref stream, .. } => stream.get_ref().get_ref().local_addr(),
}
self.connection.local_addr()
}

/// Check if the underlying connection is proxied
pub fn is_proxied(&self) -> bool {
match *self {
ProxyStream::Proxied { .. } => true,
_ => false,
}
self.connection.is_proxied()
}

/// Get reference to context
pub fn context(&self) -> &Context {
match *self {
ProxyStream::Direct { ref context, .. } => &context,
ProxyStream::Proxied { ref context, .. } => &context,
}
&self.context
}
}

macro_rules! forward_call {
($self:expr, $method:ident $(, $param:expr)*) => {
match $self.as_mut().project() {
__ProxyStreamProjection::Direct { stream, .. } => stream.$method($($param),*),
__ProxyStreamProjection::Proxied { stream, .. } => stream.$method($($param),*),
}
};
}

impl AsyncRead for ProxyStream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let p = forward_call!(self, poll_read, cx, buf);
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let p = self.as_mut().project().connection.poll_read(cx, buf);

// Flow statistic for Android client
#[cfg(feature = "local-flow-stat")]
Expand All @@ -212,8 +387,8 @@ impl AsyncRead for ProxyStream {
}

impl AsyncWrite for ProxyStream {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let p = forward_call!(self, poll_write, cx, buf);
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
let p = self.as_mut().project().connection.poll_write(cx, buf);

// Flow statistic for Android client
#[cfg(feature = "local-flow-stat")]
Expand All @@ -228,12 +403,12 @@ impl AsyncWrite for ProxyStream {
p
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<io::Result<()>> {
forward_call!(self, poll_flush, cx)
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
self.project().connection.poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<io::Result<()>> {
forward_call!(self, poll_shutdown, cx)
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
self.project().connection.poll_shutdown(cx)
}
}

Expand Down Expand Up @@ -344,33 +519,3 @@ async fn connect_proxy_server(context: &Context, svr_cfg: &ServerConfig) -> io::
);
Err(last_err)
}

/// Handshake logic for ShadowSocks Client
async fn proxy_server_handshake(
context: SharedContext,
remote_stream: STcpStream,
svr_cfg: &ServerConfig,
relay_addr: &Address,
) -> io::Result<CryptoStream<STcpStream>> {
let mut stream = CryptoStream::new(context, remote_stream, svr_cfg);

trace!("got encrypt stream and going to send addr: {:?}", relay_addr);

// Send relay address to remote
//
// NOTE: `Address` handshake packets are very small in most cases,
// so it will be sent with the IV/Nonce data (implemented inside `CryptoStream`).
//
// For lower latency, first packet should be sent back quickly,
// so TCP_NODELAY should be kept enabled until the first data packet is received.
let mut addr_buf = BytesMut::with_capacity(relay_addr.serialized_len());
relay_addr.write_to_buf(&mut addr_buf);
stream.write_all(&addr_buf).await?;

// Here we should keep the TCP_NODELAY set until the first packet is received.
// https://github.com/shadowsocks/shadowsocks-libev/pull/746
//
// Reset TCP_NODELAY after the first packet is received and sent back.

Ok(stream)
}

0 comments on commit 45a3469

Please sign in to comment.