Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frankreh/tcp write all #111

Merged
merged 2 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions examples/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use tokio_uring::net::TcpListener;
fn main() {
let args: Vec<_> = env::args().collect();

if args.len() <= 1 {
panic!("no addr specified");
}

let socket_addr: SocketAddr = args[1].parse().unwrap();
let socket_addr = if args.len() <= 1 {
"127.0.0.1:0"
} else {
args[1].as_ref()
};
let socket_addr: SocketAddr = socket_addr.parse().unwrap();

tokio_uring::start(async {
let listener = TcpListener::bind(socket_addr).unwrap();
Expand All @@ -19,14 +20,29 @@ fn main() {
loop {
let (stream, socket_addr) = listener.accept().await.unwrap();
tokio_uring::spawn(async move {
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
println!("written to {}: {}", socket_addr, result.unwrap());

let (result, buf) = stream.read(buf).await;
let read = result.unwrap();
println!("read from {}: {:?}", socket_addr, &buf[..read]);
// implement ping-pong loop

use tokio_uring::buf::IoBuf; // for slice()

println!("{} connected", socket_addr);
let mut n = 0;

let mut buf = vec![0u8; 4096];
loop {
let (result, nbuf) = stream.read(buf).await;
buf = nbuf;
let read = result.unwrap();
if read == 0 {
println!("{} closed, {} total ping-ponged", socket_addr, n);
break;
}

let (res, slice) = stream.write_all(buf.slice(..read)).await;
let _ = res.unwrap();
buf = slice.into_inner();
println!("{} all {} bytes ping-ponged", socket_addr, read);
n += read;
}
});
}
});
Expand Down
83 changes: 83 additions & 0 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,89 @@ impl TcpStream {
pub async fn write<T: IoBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
self.inner.write(buf).await
}

/// Attempts to write an entire buffer to the stream.
///
/// This method will continuously call [`write`] until there is no more data to be
/// written or an error is returned. This method will not return until the entire
/// buffer has been successfully written or an error has occurred.
///
/// If the buffer contains no data, this will never call [`write`].
///
/// # Errors
///
/// This function will return the first error that [`write`] returns.
///
/// # Examples
///
/// ```no_run
/// use std::net::SocketAddr;
/// use tokio_uring::net::TcpListener;
/// use tokio_uring::buf::IoBuf;
///
/// let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
///
/// tokio_uring::start(async {
/// let listener = TcpListener::bind(addr).unwrap();
///
/// println!("Listening on {}", listener.local_addr().unwrap());
///
/// loop {
/// let (stream, _) = listener.accept().await.unwrap();
/// tokio_uring::spawn(async move {
/// let mut n = 0;
/// let mut buf = vec![0u8; 4096];
/// loop {
/// let (result, nbuf) = stream.read(buf).await;
/// buf = nbuf;
/// let read = result.unwrap();
/// if read == 0 {
/// break;
/// }
///
/// let (res, slice) = stream.write_all(buf.slice(..read)).await;
/// let _ = res.unwrap();
/// buf = slice.into_inner();
/// n += read;
/// }
/// });
/// }
/// });
/// ```
///
/// [`write`]: Self::write
pub async fn write_all<T: IoBuf>(&self, mut buf: T) -> crate::BufResult<(), T> {
let mut n = 0;
while n < buf.bytes_init() {
let res = self.write(buf.slice(n..)).await;
match res {
(Ok(0), slice) => {
return (
Err(std::io::Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
)),
slice.into_inner(),
)
}
(Ok(m), slice) => {
n += m;
buf = slice.into_inner();
}

// This match on an EINTR error is not performed because this
// crate's design ensures we are not calling the 'wait' option
// in the ENTER syscall. Only an Enter with 'wait' can generate
// an EINTR according to the io_uring man pages.
// (Err(ref e), slice) if e.kind() == std::io::ErrorKind::Interrupted => {
// buf = slice.into_inner();
// },
(Err(e), slice) => return (Err(e), slice.into_inner()),
}
}

(Ok(()), buf)
}
}

impl AsRawFd for TcpStream {
Expand Down