Skip to content

Commit

Permalink
net: TcpStream write all (#111)
Browse files Browse the repository at this point in the history
Add an async write_all method for the TcpStream type.

Also a doc example and a rework of the examples/tcp_listener.rs
to act as a ping-pong server.

Fixes #110.
  • Loading branch information
FrankReh authored Sep 14, 2022
1 parent 587c9c3 commit 4026539
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 13 deletions.
42 changes: 29 additions & 13 deletions examples/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use tokio_uring::net::TcpListener;
fn main() {
let args: Vec<_> = env::args().collect();

if args.len() <= 1 {
panic!("no addr specified");
}

let socket_addr: SocketAddr = args[1].parse().unwrap();
let socket_addr = if args.len() <= 1 {
"127.0.0.1:0"
} else {
args[1].as_ref()
};
let socket_addr: SocketAddr = socket_addr.parse().unwrap();

tokio_uring::start(async {
let listener = TcpListener::bind(socket_addr).unwrap();
Expand All @@ -19,14 +20,29 @@ fn main() {
loop {
let (stream, socket_addr) = listener.accept().await.unwrap();
tokio_uring::spawn(async move {
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
println!("written to {}: {}", socket_addr, result.unwrap());

let (result, buf) = stream.read(buf).await;
let read = result.unwrap();
println!("read from {}: {:?}", socket_addr, &buf[..read]);
// implement ping-pong loop

use tokio_uring::buf::IoBuf; // for slice()

println!("{} connected", socket_addr);
let mut n = 0;

let mut buf = vec![0u8; 4096];
loop {
let (result, nbuf) = stream.read(buf).await;
buf = nbuf;
let read = result.unwrap();
if read == 0 {
println!("{} closed, {} total ping-ponged", socket_addr, n);
break;
}

let (res, slice) = stream.write_all(buf.slice(..read)).await;
let _ = res.unwrap();
buf = slice.into_inner();
println!("{} all {} bytes ping-ponged", socket_addr, read);
n += read;
}
});
}
});
Expand Down
83 changes: 83 additions & 0 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,89 @@ impl TcpStream {
pub async fn write<T: IoBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
self.inner.write(buf).await
}

/// Attempts to write an entire buffer to the stream.
///
/// This method will continuously call [`write`] until there is no more data to be
/// written or an error is returned. This method will not return until the entire
/// buffer has been successfully written or an error has occurred.
///
/// If the buffer contains no data, this will never call [`write`].
///
/// # Errors
///
/// This function will return the first error that [`write`] returns.
///
/// # Examples
///
/// ```no_run
/// use std::net::SocketAddr;
/// use tokio_uring::net::TcpListener;
/// use tokio_uring::buf::IoBuf;
///
/// let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
///
/// tokio_uring::start(async {
/// let listener = TcpListener::bind(addr).unwrap();
///
/// println!("Listening on {}", listener.local_addr().unwrap());
///
/// loop {
/// let (stream, _) = listener.accept().await.unwrap();
/// tokio_uring::spawn(async move {
/// let mut n = 0;
/// let mut buf = vec![0u8; 4096];
/// loop {
/// let (result, nbuf) = stream.read(buf).await;
/// buf = nbuf;
/// let read = result.unwrap();
/// if read == 0 {
/// break;
/// }
///
/// let (res, slice) = stream.write_all(buf.slice(..read)).await;
/// let _ = res.unwrap();
/// buf = slice.into_inner();
/// n += read;
/// }
/// });
/// }
/// });
/// ```
///
/// [`write`]: Self::write
pub async fn write_all<T: IoBuf>(&self, mut buf: T) -> crate::BufResult<(), T> {
let mut n = 0;
while n < buf.bytes_init() {
let res = self.write(buf.slice(n..)).await;
match res {
(Ok(0), slice) => {
return (
Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
)),
slice.into_inner(),
)
}
(Ok(m), slice) => {
n += m;
buf = slice.into_inner();
}

// This match on an EINTR error is not performed because this
// crate's design ensures we are not calling the 'wait' option
// in the ENTER syscall. Only an Enter with 'wait' can generate
// an EINTR according to the io_uring man pages.
// (Err(ref e), slice) if e.kind() == std::io::ErrorKind::Interrupted => {
// buf = slice.into_inner();
// },
(Err(e), slice) => return (Err(e), slice.into_inner()),
}
}

(Ok(()), buf)
}
}

impl AsRawFd for TcpStream {
Expand Down

0 comments on commit 4026539

Please sign in to comment.