Skip to content

Commit

Permalink
Aggressively terminate half-closed connections
Browse files Browse the repository at this point in the history
Previously, if the client closed after receiving a reply, a keepalive connection
to the server would stick around until the timeout, even though we will never
reuse it.
  • Loading branch information
mqudsi committed Jun 30, 2022
1 parent b2c2876 commit 0164ef8
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ edition = "2018"
futures = "0.3.21"
getopts = "0.2.21"
rand = "0.8.5"
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", ] }
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", "sync" ] }
# trust-dns-resolver = "0.19.5"
166 changes: 108 additions & 58 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use getopts::Options;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::join;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;

type BoxedError = Box<dyn std::error::Error + Sync + Send + 'static>;
static DEBUG: AtomicBool = AtomicBool::new(false);

fn print_usage(program: &str, opts: Options) {
let program_path = std::path::PathBuf::from(program);
let program_name = program_path.file_stem().unwrap().to_string_lossy();
let brief = format!("Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
program_name);
let brief = format!(
"Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
program_name
);
print!("{}", opts.usage(&brief));
}

Expand All @@ -21,14 +24,18 @@ async fn main() -> Result<(), BoxedError> {
let program = args[0].clone();

let mut opts = Options::new();
opts.optopt("b",
"bind",
"The address on which to listen for incoming requests, defaulting to localhost",
"BIND_ADDR");
opts.optopt("l",
"local-port",
"The local port to which tcpproxy should bind to, randomly chosen otherwise",
"LOCAL_PORT");
opts.optopt(
"b",
"bind",
"The address on which to listen for incoming requests, defaulting to localhost",
"BIND_ADDR",
);
opts.optopt(
"l",
"local-port",
"The local port to which tcpproxy should bind to, randomly chosen otherwise",
"LOCAL_PORT",
);
opts.optflag("d", "debug", "Enable debug mode");

let matches = match opts.parse(&args[1..]) {
Expand All @@ -44,7 +51,7 @@ async fn main() -> Result<(), BoxedError> {
_ => {
print_usage(&program, opts);
std::process::exit(-1);
},
}
};

if !remote.contains(':') {
Expand All @@ -71,65 +78,108 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
} else {
format!("{}:{}", bind_ip, local_port)
};
let bind_sock = bind_addr.parse::<std::net::SocketAddr>().expect("Failed to parse bind address");
let bind_sock = bind_addr
.parse::<std::net::SocketAddr>()
.expect("Failed to parse bind address");
let listener = TcpListener::bind(&bind_sock).await?;
println!("Listening on {}", listener.local_addr().unwrap());

// We have either been provided an IP address or a host name.
// Instead of trying to check its format, just trying creating a SocketAddr from it.
// let parse_result = remote.parse::<SocketAddr>();
let remote = std::sync::Arc::new(remote.to_string());

async fn copy_with_abort<R, W>(
read: &mut R,
write: &mut W,
cancel: &broadcast::Sender<()>,
) -> tokio::io::Result<usize>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
let mut abort = cancel.subscribe();
let mut copied = 0;
let mut buf = [0u8; 1024];
loop {
let bytes_read;
tokio::select! {
biased;

result = read.read(&mut buf) => {
bytes_read = result?;
},
_ = abort.recv() => {
return Ok(copied);
}
}

if bytes_read == 0 {
break;
}

write.write_all(&buf[0..bytes_read]).await?;
copied += bytes_read;
}

let _ = cancel.send(());
Ok(copied)
}

loop {
let remote = remote.clone();
let (mut client, client_addr) = listener.accept().await?;

tokio::spawn(async move {
println!("New connection from {}", client_addr);

// Establish connection to upstream for each incoming client connection
let mut remote = TcpStream::connect(remote.as_str()).await?;
let (mut client_recv, mut client_send) = client.split();
let (mut remote_recv, mut remote_send) = remote.split();

// This version of the join! macro does not require that the futures are fused and
// pinned prior to passing to join.
let (remote_bytes_copied, client_bytes_copied) = join!(
tokio::io::copy(&mut remote_recv, &mut client_send),
tokio::io::copy(&mut client_recv, &mut remote_send),
);

match remote_bytes_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!("Transferred {} bytes from remote client {} to upstream server",
count, client_addr);
}

}
Err(err) => {
eprintln!("Error writing from remote client {} to upstream server!",
client_addr);
eprintln!("{:?}", err);
println!("New connection from {}", client_addr);

// Establish connection to upstream for each incoming client connection
let mut remote = TcpStream::connect(remote.as_str()).await?;
let (mut client_read, mut client_write) = client.split();
let (mut remote_read, mut remote_write) = remote.split();

let (cancel, _) = broadcast::channel::<()>(1);
let (remote_copied, client_copied) = tokio::join! {
copy_with_abort(&mut remote_read, &mut client_write, &cancel),
copy_with_abort(&mut client_read, &mut remote_write, &cancel),
};

match client_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!(
"Transferred {} bytes from remote client {} to upstream server",
count, client_addr
);
}
};

match client_bytes_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!("Transferred {} bytes from upstream server to remote client {}",
count, client_addr);
}
}
Err(err) => {
eprintln!(
"Error writing bytes from remote client {} to upstream server",
client_addr
);
eprintln!("{}", err);
}
};

match remote_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!(
"Transferred {} bytes from upstream server to remote client {}",
count, client_addr
);
}
Err(err) => {
eprintln!("Error writing bytes from upstream server to remote client {}",
client_addr);
eprintln!("{:?}", err);
}
};

let r: Result<(), BoxedError> = Ok(());
r
}
Err(err) => {
eprintln!(
"Error writing from upstream server to remote client {}!",
client_addr
);
eprintln!("{}", err);
}
};

let r: Result<(), BoxedError> = Ok(());
r
});
}
}

0 comments on commit 0164ef8

Please sign in to comment.